diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c79ddff680..5cdf16521c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
@@ -274,6 +275,60 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
}
return results
+ @cached(num_args=2,)
+ async def get_linearized_receipts_for_all_rooms(
+ self, to_key: int, from_key: Optional[int] = None
+ ) -> Dict[str, JsonDict]:
+ """Get receipts for all rooms between two stream_ids.
+
+ Args:
+ to_key: Max stream id to fetch receipts upto.
+ from_key: Min stream id to fetch receipts from. None fetches
+ from the start.
+
+ Returns:
+ A dictionary of roomids to a list of receipts.
+ """
+
+ def f(txn):
+ if from_key:
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id > ? AND stream_id <= ?
+ """
+ txn.execute(sql, [from_key, to_key])
+ else:
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id <= ?
+ """
+
+ txn.execute(sql, [to_key])
+
+ return self.db_pool.cursor_to_dict(txn)
+
+ txn_results = await self.db_pool.runInteraction(
+ "get_linearized_receipts_for_all_rooms", f
+ )
+
+ results = {}
+ for row in txn_results:
+ # We want a single event per room, since we want to batch the
+ # receipts by room, event and type.
+ room_event = results.setdefault(
+ row["room_id"],
+ {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
+ )
+
+ # The content is of the form:
+ # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
+ event_entry = room_event["content"].setdefault(row["event_id"], {})
+ receipt_type = event_entry.setdefault(row["receipt_type"], {})
+
+ receipt_type[row["user_id"]] = db_to_json(row["data"])
+
+ return results
+
async def get_users_sent_receipts_between(
self, last_id: int, current_id: int
) -> List[str]:
|