summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-09-26 14:28:12 -0400
committerGitHub <noreply@github.com>2022-09-26 18:28:12 +0000
commit2fae1a3f7862bf38cd0b52dfd3ea3ae76794d2b7 (patch)
tree38469c5dc481a93a4c14d342ff53fd0a8b7680c3 /synapse/storage
parentUpdate the manpage documentation for the hash_password script (#13911) (diff)
downloadsynapse-2fae1a3f7862bf38cd0b52dfd3ea3ae76794d2b7.tar.xz
Improve tests for get_unread_push_actions_for_user_in_range_*. (#13893)
* Adds a docstring.
* Reduces a small amount of duplicated code.
* Improves tests.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py38
1 files changed, 24 insertions, 14 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py

index 6b8668d2dc..f4cdc2e399 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -559,7 +559,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def _get_receipts_by_room_txn( self, txn: LoggingTransaction, user_id: str - ) -> List[Tuple[str, int]]: + ) -> Dict[str, int]: + """ + Generate a map of room ID to the latest stream ordering that has been + read by the given user. + + Args: + txn: + user_id: The user to fetch receipts for. + + Returns: + A map of room ID to stream ordering for all rooms the user has a receipt in. + """ receipt_types_clause, args = make_in_list_sql_clause( self.database_engine, "receipt_type", @@ -580,7 +591,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas args.extend((user_id,)) txn.execute(sql, args) - return cast(List[Tuple[str, int]], txn.fetchall()) + return { + room_id: latest_stream_ordering + for room_id, latest_stream_ordering in txn.fetchall() + } async def get_unread_push_actions_for_user_in_range_for_http( self, @@ -605,12 +619,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - receipts_by_room = dict( - await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ), + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, ) def get_push_actions_txn( @@ -679,12 +691,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas The list will have between 0~limit entries. """ - receipts_by_room = dict( - await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_receipts", - self._get_receipts_by_room_txn, - user_id=user_id, - ), + receipts_by_room = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, ) def get_push_actions_txn(