summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/deviceinbox.py149
1 files changed, 72 insertions, 77 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 40477b9da0..fa47b471e8 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -245,33 +245,74 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 * The last-processed stream ID. Subsequent calls of this function with the
                   same device should pass this value as 'from_stream_id'.
         """
-        (
-            user_id_device_id_to_messages,
-            last_processed_stream_id,
-        ) = await self._get_device_messages(
-            user_ids=[user_id],
-            device_id=device_id,
-            from_stream_id=from_stream_id,
-            to_stream_id=to_stream_id,
-            limit=limit,
-        )
-
-        if not user_id_device_id_to_messages:
+        if not self._device_inbox_stream_cache.has_entity_changed(
+            user_id, from_stream_id
+        ):
             # There were no messages!
             return [], to_stream_id
 
-        # Extract the messages, no need to return the user and device ID again
-        to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
+        def get_device_messages_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
+            sql = """
+                SELECT stream_id, message_json FROM device_inbox
+                WHERE user_id = ? AND device_id = ?
+                AND ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (user_id, device_id, from_stream_id, to_stream_id, limit))
+
+            # Create and fill a dictionary of (user ID, device ID) -> list of messages
+            # intended for each device.
+            last_processed_stream_pos = to_stream_id
+            to_device_messages: List[JsonDict] = []
+            rowcount = 0
+            for row in txn:
+                rowcount += 1
+
+                last_processed_stream_pos = row[0]
+                message_dict = db_to_json(row[1])
+
+                # Store the device details
+                to_device_messages.append(message_dict)
 
-        return to_device_messages, last_processed_stream_id
+                # start a new span for each message, so that we can tag each separately
+                with start_active_span("get_to_device_message"):
+                    set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"])
+                    set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"])
+                    set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
+                    set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
+                    set_tag(
+                        SynapseTags.TO_DEVICE_MSGID,
+                        message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
+                    )
+
+            if rowcount == limit:
+                # We ended up bumping up against the message limit. There may be more messages
+                # to retrieve. Return what we have, as well as the last stream position that
+                # was processed.
+                #
+                # The caller is expected to set this as the lower (exclusive) bound
+                # for the next query of this device.
+                return to_device_messages, last_processed_stream_pos
+
+            # The limit was not reached, thus we know that recipient_device_to_messages
+            # contains all to-device messages for the given device and stream id range.
+            #
+            # We return to_stream_id, which the caller should then provide as the lower
+            # (exclusive) bound on the next query of this device.
+            return to_device_messages, to_stream_id
+
+        return await self.db_pool.runInteraction(
+            "get_messages_for_device", get_device_messages_txn
+        )
 
     async def _get_device_messages(
         self,
         user_ids: Collection[str],
         from_stream_id: int,
         to_stream_id: int,
-        device_id: Optional[str] = None,
-        limit: Optional[int] = None,
     ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
         """
         Retrieve pending to-device messages for a collection of user devices.
@@ -291,11 +332,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             user_ids: The user IDs to filter device messages by.
             from_stream_id: The lower boundary of stream id to filter with (exclusive).
             to_stream_id: The upper boundary of stream id to filter with (inclusive).
-            device_id: A device ID to query to-device messages for. If not provided, to-device
-                messages from all device IDs for the given user IDs will be queried. May not be
-                provided if `user_ids` contains more than one entry.
-            limit: The maximum number of to-device messages to return. Can only be used when
-                passing a single user ID / device ID tuple.
+
 
         Returns:
             A tuple containing:
@@ -308,30 +345,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             logger.warning("No users provided upon querying for device IDs")
             return {}, to_stream_id
 
-        # Prevent a query for one user's device also retrieving another user's device with
-        # the same device ID (device IDs are not unique across users).
-        if len(user_ids) > 1 and device_id is not None:
-            raise AssertionError(
-                "Programming error: 'device_id' cannot be supplied to "
-                "_get_device_messages when >1 user_id has been provided"
-            )
-
-        # A limit can only be applied when querying for a single user ID / device ID tuple.
-        # See the docstring of this function for more details.
-        if limit is not None and device_id is None:
-            raise AssertionError(
-                "Programming error: _get_device_messages was passed 'limit' "
-                "without a specific user_id/device_id"
-            )
-
         user_ids_to_query: Set[str] = set()
-        device_ids_to_query: Set[str] = set()
-
-        # Note that a device ID could be an empty str
-        if device_id is not None:
-            # If a device ID was passed, use it to filter results.
-            # Otherwise, device IDs will be derived from the given collection of user IDs.
-            device_ids_to_query.add(device_id)
 
         # Determine which users have devices with pending messages
         for user_id in user_ids:
@@ -355,20 +369,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             # hidden devices should not receive to-device messages.
             # Note that this is more efficient than just dropping `device_id` from the query,
             # since device_inbox has an index on `(user_id, device_id, stream_id)`
-            if not device_ids_to_query:
-                user_device_dicts = cast(
-                    List[Tuple[str]],
-                    self.db_pool.simple_select_many_txn(
-                        txn,
-                        table="devices",
-                        column="user_id",
-                        iterable=user_ids_to_query,
-                        keyvalues={"hidden": False},
-                        retcols=("device_id",),
-                    ),
-                )
 
-                device_ids_to_query.update({row[0] for row in user_device_dicts})
+            user_device_dicts = cast(
+                List[Tuple[str]],
+                self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="devices",
+                    column="user_id",
+                    iterable=user_ids_to_query,
+                    keyvalues={"hidden": False},
+                    retcols=("device_id",),
+                ),
+            )
+
+            device_ids_to_query = {row[0] for row in user_device_dicts}
 
             if not device_ids_to_query:
                 # We've ended up with no devices to query.
@@ -400,22 +414,15 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 to_stream_id,
             )
 
-            # If a limit was provided, limit the data retrieved from the database
-            if limit is not None:
-                sql += "LIMIT ?"
-                sql_args += (limit,)
-
             txn.execute(sql, sql_args)
 
             # Create and fill a dictionary of (user ID, device ID) -> list of messages
             # intended for each device.
-            last_processed_stream_pos = to_stream_id
             recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
             rowcount = 0
             for row in txn:
                 rowcount += 1
 
-                last_processed_stream_pos = row[0]
                 recipient_user_id = row[1]
                 recipient_device_id = row[2]
                 message_dict = db_to_json(row[3])
@@ -436,18 +443,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                         message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
                     )
 
-            if limit is not None and rowcount == limit:
-                # We ended up bumping up against the message limit. There may be more messages
-                # to retrieve. Return what we have, as well as the last stream position that
-                # was processed.
-                #
-                # The caller is expected to set this as the lower (exclusive) bound
-                # for the next query of this device.
-                return recipient_device_to_messages, last_processed_stream_pos
-
-            # The limit was not reached, thus we know that recipient_device_to_messages
-            # contains all to-device messages for the given device and stream id range.
-            #
             # We return to_stream_id, which the caller should then provide as the lower
             # (exclusive) bound on the next query of this device.
             return recipient_device_to_messages, to_stream_id