summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/databases/main/deviceinbox.py76
1 files changed, 75 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 7c0f953365..554c7a549d 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -24,6 +24,7 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    make_in_list_sql_clause,
 )
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
@@ -136,6 +137,79 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
 
+    async def get_new_messages(
+        self,
+        user_ids: Collection[str],
+        from_stream_id: int,
+        to_stream_id: int,
+    ) -> Dict[Tuple[str, str], List[JsonDict]]:
+        """
+        Retrieve to-device messages for a given set of user IDs.
+
+        Only to-device messages with stream ids between the given boundaries
+        (from < X <= to) are returned.
+
+        Note that a stream ID can be shared by multiple copies of the same message with
+        different recipient devices. Each (device, message_content) tuple has their own
+        row in the device_inbox table.
+
+        Args:
+            user_ids: The users to retrieve to-device messages for.
+            from_stream_id: The lower boundary of stream id to filter with (exclusive).
+            to_stream_id: The upper boundary of stream id to filter with (inclusive).
+
+        Returns:
+            A list of to-device messages.
+        """
+        # Bail out if none of these users have any messages
+        for user_id in user_ids:
+            if self._device_inbox_stream_cache.has_entity_changed(
+                user_id, from_stream_id
+            ):
+                break
+        else:
+            return {}
+
+        def get_new_messages_txn(txn: LoggingTransaction):
+            # Build a query to select messages from any of the given users that are between
+            # the given stream id bounds
+
+            # Scope to only the given users. We need to use this method as doing so is
+            # different across database engines.
+            many_clause_sql, many_clause_args = make_in_list_sql_clause(
+                self.database_engine, "user_id", user_ids
+            )
+
+            sql = f"""
+                SELECT user_id, device_id, message_json FROM device_inbox
+                WHERE {many_clause_sql}
+                AND ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+            """
+
+            txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id))
+
+            # Create a dictionary of (user ID, device ID) -> list of messages that
+            # that device is meant to receive.
+            recipient_user_id_device_id_to_messages: Dict[
+                Tuple[str, str], List[JsonDict]
+            ] = {}
+
+            for row in txn:
+                recipient_user_id = row[0]
+                recipient_device_id = row[1]
+                message_dict = db_to_json(row[2])
+
+                recipient_user_id_device_id_to_messages.setdefault(
+                    (recipient_user_id, recipient_device_id), []
+                ).append(message_dict)
+
+            return recipient_user_id_device_id_to_messages
+
+        return await self.db_pool.runInteraction(
+            "get_new_messages", get_new_messages_txn
+        )
+
     async def get_new_messages_for_device(
         self,
         user_id: str,