diff options
author | Andrew Morgan <andrew@amorgan.xyz> | 2021-11-05 15:59:08 +0000 |
---|---|---|
committer | Andrew Morgan <andrew@amorgan.xyz> | 2021-11-16 12:59:17 +0000 |
commit | 7899f823ae6f6987811e027b3c74104197cf2482 (patch) | |
tree | 044b9e4b1cd87a1a021767b4d604baed95e662e5 | |
parent | Allow setting/getting stream id per appservice for to-device messages (diff) | |
download | synapse-7899f823ae6f6987811e027b3c74104197cf2482.tar.xz |
Add database method to fetch to-device messages by user_ids from db
This method is quite similar to the one below, except that it doesn't support device ids, and supports querying with more than one user id, both of which are relevant to application services. The results are also formatted in a different data structure, so I'm not sure how much we could really share here between the two methods.
-rw-r--r-- | synapse/storage/databases/main/deviceinbox.py | 76 |
1 files changed, 75 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 7c0f953365..554c7a549d 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -24,6 +24,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( @@ -136,6 +137,79 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() + async def get_new_messages( + self, + user_ids: Collection[str], + from_stream_id: int, + to_stream_id: int, + ) -> Dict[Tuple[str, str], List[JsonDict]]: + """ + Retrieve to-device messages for a given set of user IDs. + + Only to-device messages with stream ids between the given boundaries + (from < X <= to) are returned. + + Note that a stream ID can be shared by multiple copies of the same message with + different recipient devices. Each (device, message_content) tuple has their own + row in the device_inbox table. + + Args: + user_ids: The users to retrieve to-device messages for. + from_stream_id: The lower boundary of stream id to filter with (exclusive). + to_stream_id: The upper boundary of stream id to filter with (inclusive). + + Returns: + A list of to-device messages. + """ + # Bail out if none of these users have any messages + for user_id in user_ids: + if self._device_inbox_stream_cache.has_entity_changed( + user_id, from_stream_id + ): + break + else: + return {} + + def get_new_messages_txn(txn: LoggingTransaction): + # Build a query to select messages from any of the given users that are between + # the given stream id bounds + + # Scope to only the given users. We need to use this method as doing so is + # different across database engines. + many_clause_sql, many_clause_args = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + + sql = f""" + SELECT user_id, device_id, message_json FROM device_inbox + WHERE {many_clause_sql} + AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + + txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id)) + + # Create a dictionary of (user ID, device ID) -> list of messages that + # that device is meant to receive. + recipient_user_id_device_id_to_messages: Dict[ + Tuple[str, str], List[JsonDict] + ] = {} + + for row in txn: + recipient_user_id = row[0] + recipient_device_id = row[1] + message_dict = db_to_json(row[2]) + + recipient_user_id_device_id_to_messages.setdefault( + (recipient_user_id, recipient_device_id), [] + ).append(message_dict) + + return recipient_user_id_device_id_to_messages + + return await self.db_pool.runInteraction( + "get_new_messages", get_new_messages_txn + ) + async def get_new_messages_for_device( self, user_id: str, |