summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py38
1 files changed, 24 insertions, 14 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 6b8668d2dc..f4cdc2e399 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -559,7 +559,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
     def _get_receipts_by_room_txn(
         self, txn: LoggingTransaction, user_id: str
-    ) -> List[Tuple[str, int]]:
+    ) -> Dict[str, int]:
+        """
+        Generate a map of room ID to the latest stream ordering that has been
+        read by the given user.
+
+        Args:
+            txn:
+            user_id: The user to fetch receipts for.
+
+        Returns:
+            A map of room ID to stream ordering for all rooms the user has a receipt in.
+        """
         receipt_types_clause, args = make_in_list_sql_clause(
             self.database_engine,
             "receipt_type",
@@ -580,7 +591,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         args.extend((user_id,))
         txn.execute(sql, args)
-        return cast(List[Tuple[str, int]], txn.fetchall())
+        return {
+            room_id: latest_stream_ordering
+            for room_id, latest_stream_ordering in txn.fetchall()
+        }
 
     async def get_unread_push_actions_for_user_in_range_for_http(
         self,
@@ -605,12 +619,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        receipts_by_room = dict(
-            await self.db_pool.runInteraction(
-                "get_unread_push_actions_for_user_in_range_http_receipts",
-                self._get_receipts_by_room_txn,
-                user_id=user_id,
-            ),
+        receipts_by_room = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_http_receipts",
+            self._get_receipts_by_room_txn,
+            user_id=user_id,
         )
 
         def get_push_actions_txn(
@@ -679,12 +691,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        receipts_by_room = dict(
-            await self.db_pool.runInteraction(
-                "get_unread_push_actions_for_user_in_range_email_receipts",
-                self._get_receipts_by_room_txn,
-                user_id=user_id,
-            ),
+        receipts_by_room = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_email_receipts",
+            self._get_receipts_by_room_txn,
+            user_id=user_id,
         )
 
         def get_push_actions_txn(