summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-04-05 12:46:34 +0100
committerGitHub <noreply@github.com>2024-04-05 12:46:34 +0100
commit5360baeb6439366c29d55038da7f677c64eea4bf (patch)
tree964ceb5610ce4fb1ebf8aed44faa15b406b4d6d5 /synapse/storage/databases
parentFix bug in calculating state for non-gappy syncs (#16942) (diff)
downloadsynapse-5360baeb6439366c29d55038da7f677c64eea4bf.tar.xz
Pull out fewer receipts from DB when doing push (#17049)
Before we were pulling out *all* read receipts for a user for every
event we pushed. Instead let's only pull out the relevant receipts.

This also pulled out the event rows for each receipt, causing load on
the events table.
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py124
1 files changed, 102 insertions, 22 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3a5666cd9b..40bf000e9c 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -106,7 +106,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.stream import StreamWorkerStore
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -859,37 +859,86 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
 
-    def _get_receipts_by_room_txn(
-        self, txn: LoggingTransaction, user_id: str
+    def _get_receipts_for_room_and_threads_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        room_ids: StrCollection,
+        thread_ids: StrCollection,
     ) -> Dict[str, _RoomReceipt]:
         """
-        Generate a map of room ID to the latest stream ordering that has been
-        read by the given user.
+        Get (private) read receipts for a user in each of the given room IDs
+        and thread IDs.
 
-        Args:
-            txn:
-            user_id: The user to fetch receipts for.
+        Note: The corresponding room ID for each thread must appear in
+        `room_ids` arg.
 
         Returns:
             A map including all rooms the user is in with a receipt. It maps
             room IDs to _RoomReceipt instances
         """
-        receipt_types_clause, args = make_in_list_sql_clause(
+
+        receipt_types_clause, receipts_args = make_in_list_sql_clause(
             self.database_engine,
             "receipt_type",
             (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
         )
 
+        thread_ids_clause, thread_ids_args = make_in_list_sql_clause(
+            self.database_engine,
+            "thread_id",
+            thread_ids,
+        )
+
+        room_ids_clause, room_ids_args = make_in_list_sql_clause(
+            self.database_engine,
+            "room_id",
+            room_ids,
+        )
+
+        # We use the union of two (almost identical) queries here, the first to
+        # fetch the specific thread receipts and the second to fetch the
+        # unthreaded receipts.
+        #
+        # This SQL is optimized to use the indices we have on
+        # `receipts_linearized`.
+        #
+        # We compare room ID and thread IDs independently due to the above,
+        # which means that this query might return more rows than we need if the
+        # same thread ID appears across different rooms (e.g. 'main' thread ID).
+        # This doesn't cause any logic issues, and isn't a performance concern
+        # given this function generally gets called with only one room and
+        # thread ID.
         sql = f"""
             SELECT room_id, thread_id, MAX(stream_ordering)
             FROM receipts_linearized
             INNER JOIN events USING (room_id, event_id)
             WHERE {receipt_types_clause}
+                AND {thread_ids_clause}
+                AND {room_ids_clause}
+            AND user_id = ?
+            GROUP BY room_id, thread_id
+
+        UNION ALL
+
+            SELECT room_id, thread_id, MAX(stream_ordering)
+            FROM receipts_linearized
+            INNER JOIN events USING (room_id, event_id)
+            WHERE {receipt_types_clause}
+                AND {room_ids_clause}
+                AND thread_id IS NULL
             AND user_id = ?
             GROUP BY room_id, thread_id
         """
 
-        args.extend((user_id,))
+        args = list(receipts_args)
+        args.extend(thread_ids_args)
+        args.extend(room_ids_args)
+        args.append(user_id)
+        args.extend(receipts_args)
+        args.extend(room_ids_args)
+        args.append(user_id)
+
         txn.execute(sql, args)
 
         result: Dict[str, _RoomReceipt] = {}
@@ -925,12 +974,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        receipts_by_room = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_receipts",
-            self._get_receipts_by_room_txn,
-            user_id=user_id,
-        )
-
         def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, str, int, str, bool]]:
@@ -952,6 +995,27 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
         )
 
+        room_ids = set()
+        thread_ids = []
+        for (
+            _,
+            room_id,
+            thread_id,
+            _,
+            _,
+            _,
+        ) in push_actions:
+            room_ids.add(room_id)
+            thread_ids.append(thread_id)
+
+        receipts_by_room = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_http_receipts",
+            self._get_receipts_for_room_and_threads_txn,
+            user_id=user_id,
+            room_ids=room_ids,
+            thread_ids=thread_ids,
+        )
+
         notifs = [
             HttpPushAction(
                 event_id=event_id,
@@ -998,12 +1062,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        receipts_by_room = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_receipts",
-            self._get_receipts_by_room_txn,
-            user_id=user_id,
-        )
-
         def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, str, int, str, bool, int]]:
@@ -1026,6 +1084,28 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
         )
 
+        room_ids = set()
+        thread_ids = []
+        for (
+            _,
+            room_id,
+            thread_id,
+            _,
+            _,
+            _,
+            _,
+        ) in push_actions:
+            room_ids.add(room_id)
+            thread_ids.append(thread_id)
+
+        receipts_by_room = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_email_receipts",
+            self._get_receipts_for_room_and_threads_txn,
+            user_id=user_id,
+            room_ids=room_ids,
+            thread_ids=thread_ids,
+        )
+
         # Make a list of dicts from the two sets of results.
         notifs = [
             EmailPushAction(