diff options
-rw-r--r-- | synapse/storage/databases/main/deviceinbox.py | 76 |
1 files changed, 75 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 7c0f953365..554c7a549d 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -24,6 +24,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( @@ -136,6 +137,79 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() + async def get_new_messages( + self, + user_ids: Collection[str], + from_stream_id: int, + to_stream_id: int, + ) -> Dict[Tuple[str, str], List[JsonDict]]: + """ + Retrieve to-device messages for a given set of user IDs. + + Only to-device messages with stream ids between the given boundaries + (from < X <= to) are returned. + + Note that a stream ID can be shared by multiple copies of the same message with + different recipient devices. Each (device, message_content) tuple has their own + row in the device_inbox table. + + Args: + user_ids: The users to retrieve to-device messages for. + from_stream_id: The lower boundary of stream id to filter with (exclusive). + to_stream_id: The upper boundary of stream id to filter with (inclusive). + + Returns: + A list of to-device messages. + """ + # Bail out if none of these users have any messages + for user_id in user_ids: + if self._device_inbox_stream_cache.has_entity_changed( + user_id, from_stream_id + ): + break + else: + return {} + + def get_new_messages_txn(txn: LoggingTransaction): + # Build a query to select messages from any of the given users that are between + # the given stream id bounds + + # Scope to only the given users. We need to use this method as doing so is + # different across database engines. + many_clause_sql, many_clause_args = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids + ) + + sql = f""" + SELECT user_id, device_id, message_json FROM device_inbox + WHERE {many_clause_sql} + AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + + txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id)) + + # Create a dictionary of (user ID, device ID) -> list of messages that + # that device is meant to receive. + recipient_user_id_device_id_to_messages: Dict[ + Tuple[str, str], List[JsonDict] + ] = {} + + for row in txn: + recipient_user_id = row[0] + recipient_device_id = row[1] + message_dict = db_to_json(row[2]) + + recipient_user_id_device_id_to_messages.setdefault( + (recipient_user_id, recipient_device_id), [] + ).append(message_dict) + + return recipient_user_id_device_id_to_messages + + return await self.db_pool.runInteraction( + "get_new_messages", get_new_messages_txn + ) + async def get_new_messages_for_device( self, user_id: str, |