diff --git a/changelog.d/17049.misc b/changelog.d/17049.misc
new file mode 100644
index 0000000000..f71a6473a2
--- /dev/null
+++ b/changelog.d/17049.misc
@@ -0,0 +1 @@
+Improve database performance by reducing the number of receipts fetched when sending push notifications.
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3a5666cd9b..40bf000e9c 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -106,7 +106,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -859,37 +859,86 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
- def _get_receipts_by_room_txn(
- self, txn: LoggingTransaction, user_id: str
+ def _get_receipts_for_room_and_threads_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ room_ids: StrCollection,
+ thread_ids: StrCollection,
) -> Dict[str, _RoomReceipt]:
"""
- Generate a map of room ID to the latest stream ordering that has been
- read by the given user.
+ Get (private) read receipts for a user in each of the given room IDs
+ and thread IDs.
- Args:
- txn:
- user_id: The user to fetch receipts for.
+ Note: The corresponding room ID for each thread must appear in the
+ `room_ids` argument.
Returns:
A map including all rooms the user is in with a receipt. It maps
room IDs to _RoomReceipt instances
"""
- receipt_types_clause, args = make_in_list_sql_clause(
+
+ receipt_types_clause, receipts_args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
+ thread_ids_clause, thread_ids_args = make_in_list_sql_clause(
+ self.database_engine,
+ "thread_id",
+ thread_ids,
+ )
+
+ room_ids_clause, room_ids_args = make_in_list_sql_clause(
+ self.database_engine,
+ "room_id",
+ room_ids,
+ )
+
+ # We use the union of two (almost identical) queries here, the first to
+ # fetch the specific thread receipts and the second to fetch the
+ # unthreaded receipts.
+ #
+ # This SQL is optimized to use the indices we have on
+ # `receipts_linearized`.
+ #
+ # We compare room ID and thread IDs independently due to the above,
+ # which means that this query might return more rows than we need if the
+ # same thread ID appears across different rooms (e.g. 'main' thread ID).
+ # This doesn't cause any logic issues, and isn't a performance concern
+ # given this function generally gets called with only one room and
+ # thread ID.
sql = f"""
SELECT room_id, thread_id, MAX(stream_ordering)
FROM receipts_linearized
INNER JOIN events USING (room_id, event_id)
WHERE {receipt_types_clause}
+ AND {thread_ids_clause}
+ AND {room_ids_clause}
+ AND user_id = ?
+ GROUP BY room_id, thread_id
+
+ UNION ALL
+
+ SELECT room_id, thread_id, MAX(stream_ordering)
+ FROM receipts_linearized
+ INNER JOIN events USING (room_id, event_id)
+ WHERE {receipt_types_clause}
+ AND {room_ids_clause}
+ AND thread_id IS NULL
AND user_id = ?
GROUP BY room_id, thread_id
"""
- args.extend((user_id,))
+ args = list(receipts_args)
+ args.extend(thread_ids_args)
+ args.extend(room_ids_args)
+ args.append(user_id)
+ args.extend(receipts_args)
+ args.extend(room_ids_args)
+ args.append(user_id)
+
txn.execute(sql, args)
result: Dict[str, _RoomReceipt] = {}
@@ -925,12 +974,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries.
"""
- receipts_by_room = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_http_receipts",
- self._get_receipts_by_room_txn,
- user_id=user_id,
- )
-
def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, str, bool]]:
@@ -952,6 +995,27 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
"get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
)
+ room_ids = set()
+ thread_ids = []
+ for (
+ _,
+ room_id,
+ thread_id,
+ _,
+ _,
+ _,
+ ) in push_actions:
+ room_ids.add(room_id)
+ thread_ids.append(thread_id)
+
+ receipts_by_room = await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_http_receipts",
+ self._get_receipts_for_room_and_threads_txn,
+ user_id=user_id,
+ room_ids=room_ids,
+ thread_ids=thread_ids,
+ )
+
notifs = [
HttpPushAction(
event_id=event_id,
@@ -998,12 +1062,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries.
"""
- receipts_by_room = await self.db_pool.runInteraction(
- "get_unread_push_actions_for_user_in_range_email_receipts",
- self._get_receipts_by_room_txn,
- user_id=user_id,
- )
-
def get_push_actions_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, str, bool, int]]:
@@ -1026,6 +1084,28 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
"get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
)
+ room_ids = set()
+ thread_ids = []
+ for (
+ _,
+ room_id,
+ thread_id,
+ _,
+ _,
+ _,
+ _,
+ ) in push_actions:
+ room_ids.add(room_id)
+ thread_ids.append(thread_id)
+
+ receipts_by_room = await self.db_pool.runInteraction(
+ "get_unread_push_actions_for_user_in_range_email_receipts",
+ self._get_receipts_for_room_and_threads_txn,
+ user_id=user_id,
+ room_ids=room_ids,
+ thread_ids=thread_ids,
+ )
+
# Make a list of dicts from the two sets of results.
notifs = [
EmailPushAction(
|