summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py124
1 files changed, 102 insertions, 22 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3a5666cd9b..40bf000e9c 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -106,7 +106,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.stream import StreamWorkerStore
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -859,37 +859,86 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
 
-    def _get_receipts_by_room_txn(
-        self, txn: LoggingTransaction, user_id: str
+    def _get_receipts_for_room_and_threads_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        room_ids: StrCollection,
+        thread_ids: StrCollection,
     ) -> Dict[str, _RoomReceipt]:
         """
-        Generate a map of room ID to the latest stream ordering that has been
-        read by the given user.
+        Get (private) read receipts for a user in each of the given room IDs
+        and thread IDs.
 
-        Args:
-            txn:
-            user_id: The user to fetch receipts for.
+        Note: The corresponding room ID for each thread must appear in
+        `room_ids` arg.
 
         Returns:
             A map including all rooms the user is in with a receipt. It maps
             room IDs to _RoomReceipt instances
         """
-        receipt_types_clause, args = make_in_list_sql_clause(
+
+        receipt_types_clause, receipts_args = make_in_list_sql_clause(
             self.database_engine,
             "receipt_type",
             (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
         )
 
+        thread_ids_clause, thread_ids_args = make_in_list_sql_clause(
+            self.database_engine,
+            "thread_id",
+            thread_ids,
+        )
+
+        room_ids_clause, room_ids_args = make_in_list_sql_clause(
+            self.database_engine,
+            "room_id",
+            room_ids,
+        )
+
+        # We use the union of two (almost identical) queries here, the first to
+        # fetch the specific thread receipts and the second to fetch the
+        # unthreaded receipts.
+        #
+        # This SQL is optimized to use the indices we have on
+        # `receipts_linearized`.
+        #
+        # We compare room ID and thread IDs independently due to the above,
+        # which means that this query might return more rows than we need if the
+        # same thread ID appears across different rooms (e.g. 'main' thread ID).
+        # This doesn't cause any logic issues, and isn't a performance concern
+        # given this function generally gets called with only one room and
+        # thread ID.
         sql = f"""
             SELECT room_id, thread_id, MAX(stream_ordering)
             FROM receipts_linearized
             INNER JOIN events USING (room_id, event_id)
             WHERE {receipt_types_clause}
+                AND {thread_ids_clause}
+                AND {room_ids_clause}
+            AND user_id = ?
+            GROUP BY room_id, thread_id
+
+        UNION ALL
+
+            SELECT room_id, thread_id, MAX(stream_ordering)
+            FROM receipts_linearized
+            INNER JOIN events USING (room_id, event_id)
+            WHERE {receipt_types_clause}
+                AND {room_ids_clause}
+                AND thread_id IS NULL
             AND user_id = ?
             GROUP BY room_id, thread_id
         """
 
-        args.extend((user_id,))
+        args = list(receipts_args)
+        args.extend(thread_ids_args)
+        args.extend(room_ids_args)
+        args.append(user_id)
+        args.extend(receipts_args)
+        args.extend(room_ids_args)
+        args.append(user_id)
+
         txn.execute(sql, args)
 
         result: Dict[str, _RoomReceipt] = {}
@@ -925,12 +974,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        receipts_by_room = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_receipts",
-            self._get_receipts_by_room_txn,
-            user_id=user_id,
-        )
-
         def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, str, int, str, bool]]:
@@ -952,6 +995,27 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
         )
 
+        room_ids = set()
+        thread_ids = []
+        for (
+            _,
+            room_id,
+            thread_id,
+            _,
+            _,
+            _,
+        ) in push_actions:
+            room_ids.add(room_id)
+            thread_ids.append(thread_id)
+
+        receipts_by_room = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_http_receipts",
+            self._get_receipts_for_room_and_threads_txn,
+            user_id=user_id,
+            room_ids=room_ids,
+            thread_ids=thread_ids,
+        )
+
         notifs = [
             HttpPushAction(
                 event_id=event_id,
@@ -998,12 +1062,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        receipts_by_room = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_receipts",
-            self._get_receipts_by_room_txn,
-            user_id=user_id,
-        )
-
         def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, str, int, str, bool, int]]:
@@ -1026,6 +1084,28 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
         )
 
+        room_ids = set()
+        thread_ids = []
+        for (
+            _,
+            room_id,
+            thread_id,
+            _,
+            _,
+            _,
+            _,
+        ) in push_actions:
+            room_ids.add(room_id)
+            thread_ids.append(thread_id)
+
+        receipts_by_room = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_email_receipts",
+            self._get_receipts_for_room_and_threads_txn,
+            user_id=user_id,
+            room_ids=room_ids,
+            thread_ids=thread_ids,
+        )
+
         # Make a list of dicts from the two sets of results.
         notifs = [
             EmailPushAction(