diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/storage/databases/main/event_push_actions.py | 124 |
1 files changed, 102 insertions, 22 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3a5666cd9b..40bf000e9c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -106,7 +106,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore -from synapse.types import JsonDict +from synapse.types import JsonDict, StrCollection from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -859,37 +859,86 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas return await self.db_pool.runInteraction("get_push_action_users_in_range", f) - def _get_receipts_by_room_txn( - self, txn: LoggingTransaction, user_id: str + def _get_receipts_for_room_and_threads_txn( + self, + txn: LoggingTransaction, + user_id: str, + room_ids: StrCollection, + thread_ids: StrCollection, ) -> Dict[str, _RoomReceipt]: """ - Generate a map of room ID to the latest stream ordering that has been - read by the given user. + Get (private) read receipts for a user in each of the given room IDs + and thread IDs. - Args: - txn: - user_id: The user to fetch receipts for. + Note: The corresponding room ID for each thread must appear in + `room_ids` arg. Returns: A map including all rooms the user is in with a receipt. It maps room IDs to _RoomReceipt instances """ - receipt_types_clause, args = make_in_list_sql_clause( + + receipt_types_clause, receipts_args = make_in_list_sql_clause( self.database_engine, "receipt_type", (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) + thread_ids_clause, thread_ids_args = make_in_list_sql_clause( + self.database_engine, + "thread_id", + thread_ids, + ) + + room_ids_clause, room_ids_args = make_in_list_sql_clause( + self.database_engine, + "room_id", + room_ids, + ) + + # We use the union of two (almost identical) queries here, the first to + # fetch the specific thread receipts and the second to fetch the + # unthreaded receipts. + # + # This SQL is optimized to use the indices we have on + # `receipts_linearized`. + # + # We compare room ID and thread IDs independently due to the above, + # which means that this query might return more rows than we need if the + # same thread ID appears across different rooms (e.g. 'main' thread ID). + # This doesn't cause any logic issues, and isn't a performance concern + # given this function generally gets called with only one room and + # thread ID. sql = f""" SELECT room_id, thread_id, MAX(stream_ordering) FROM receipts_linearized INNER JOIN events USING (room_id, event_id) WHERE {receipt_types_clause} + AND {thread_ids_clause} + AND {room_ids_clause} + AND user_id = ? + GROUP BY room_id, thread_id + + UNION ALL + + SELECT room_id, thread_id, MAX(stream_ordering) + FROM receipts_linearized + INNER JOIN events USING (room_id, event_id) + WHERE {receipt_types_clause} + AND {room_ids_clause} + AND thread_id IS NULL AND user_id = ? GROUP BY room_id, thread_id """ - args.extend((user_id,)) + args = list(receipts_args) + args.extend(thread_ids_args) + args.extend(room_ids_args) + args.append(user_id) + args.extend(receipts_args) + args.extend(room_ids_args) + args.append(user_id) + txn.execute(sql, args) result: Dict[str, _RoomReceipt] = {} @@ -925,12 +974,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - receipts_by_room = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ) - def get_push_actions_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, str, int, str, bool]]: @@ -952,6 +995,27 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn ) + room_ids = set() + thread_ids = [] + for ( + _, + room_id, + thread_id, + _, + _, + _, + ) in push_actions: + room_ids.add(room_id) + thread_ids.append(thread_id) + + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http_receipts", + self._get_receipts_for_room_and_threads_txn, + user_id=user_id, + room_ids=room_ids, + thread_ids=thread_ids, + ) + notifs = [ HttpPushAction( event_id=event_id, @@ -998,12 +1062,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - receipts_by_room = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ) - def get_push_actions_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, str, int, str, bool, int]]: @@ -1026,6 +1084,28 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn ) + room_ids = set() + thread_ids = [] + for ( + _, + room_id, + thread_id, + _, + _, + _, + _, + ) in push_actions: + room_ids.add(room_id) + thread_ids.append(thread_id) + + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email_receipts", + self._get_receipts_for_room_and_threads_txn, + user_id=user_id, + room_ids=room_ids, + thread_ids=thread_ids, + ) + # Make a list of dicts from the two sets of results. notifs = [ EmailPushAction( |