summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/appservice.py24
-rw-r--r--synapse/storage/databases/main/deviceinbox.py276
2 files changed, 254 insertions, 46 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 2bb5288431..304814af5d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -198,6 +198,7 @@ class ApplicationServiceTransactionWorkerStore(
         service: ApplicationService,
         events: List[EventBase],
         ephemeral: List[JsonDict],
+        to_device_messages: List[JsonDict],
     ) -> AppServiceTransaction:
         """Atomically creates a new transaction for this application service
         with the given list of events. Ephemeral events are NOT persisted to the
@@ -207,6 +208,7 @@ class ApplicationServiceTransactionWorkerStore(
             service: The service who the transaction is for.
             events: A list of persistent events to put in the transaction.
             ephemeral: A list of ephemeral events to put in the transaction.
+            to_device_messages: A list of to-device messages to put in the transaction.
 
         Returns:
             A new transaction.
@@ -237,7 +239,11 @@ class ApplicationServiceTransactionWorkerStore(
                 (service.id, new_txn_id, event_ids),
             )
             return AppServiceTransaction(
-                service=service, id=new_txn_id, events=events, ephemeral=ephemeral
+                service=service,
+                id=new_txn_id,
+                events=events,
+                ephemeral=ephemeral,
+                to_device_messages=to_device_messages,
             )
 
         return await self.db_pool.runInteraction(
@@ -330,7 +336,11 @@ class ApplicationServiceTransactionWorkerStore(
         events = await self.get_events_as_list(event_ids)
 
         return AppServiceTransaction(
-            service=service, id=entry["txn_id"], events=events, ephemeral=[]
+            service=service,
+            id=entry["txn_id"],
+            events=events,
+            ephemeral=[],
+            to_device_messages=[],
         )
 
     def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@@ -391,7 +401,7 @@ class ApplicationServiceTransactionWorkerStore(
     async def get_type_stream_id_for_appservice(
         self, service: ApplicationService, type: str
     ) -> int:
-        if type not in ("read_receipt", "presence"):
+        if type not in ("read_receipt", "presence", "to_device"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (type,)
@@ -415,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore(
             "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
         )
 
-    async def set_type_stream_id_for_appservice(
+    async def set_appservice_stream_type_pos(
         self, service: ApplicationService, stream_type: str, pos: Optional[int]
     ) -> None:
-        if stream_type not in ("read_receipt", "presence"):
+        if stream_type not in ("read_receipt", "presence", "to_device"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (stream_type,)
             )
 
-        def set_type_stream_id_for_appservice_txn(txn):
+        def set_appservice_stream_type_pos_txn(txn):
             stream_id_type = "%s_stream_id" % stream_type
             txn.execute(
                 "UPDATE application_services_state SET %s = ? WHERE as_id=?"
@@ -433,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore(
             )
 
         await self.db_pool.runInteraction(
-            "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
+            "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
         )
 
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 4eca97189b..8801b7b2dd 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -24,6 +24,7 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    make_in_list_sql_clause,
 )
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
@@ -136,63 +137,260 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
 
-    async def get_new_messages_for_device(
+    async def get_messages_for_user_devices(
+        self,
+        user_ids: Collection[str],
+        from_stream_id: int,
+        to_stream_id: int,
+    ) -> Dict[Tuple[str, str], List[JsonDict]]:
+        """
+        Retrieve to-device messages for a given set of users.
+
+        Only to-device messages with stream ids between the given boundaries
+        (from < X <= to) are returned.
+
+        Args:
+            user_ids: The users to retrieve to-device messages for.
+            from_stream_id: The lower boundary of stream id to filter with (exclusive).
+            to_stream_id: The upper boundary of stream id to filter with (inclusive).
+
+        Returns:
+            A dictionary of (user id, device id) -> list of to-device messages.
+        """
+        # We expect the stream ID returned by _get_device_messages to always
+        # be to_stream_id. So, no need to return it from this function.
+        (
+            user_id_device_id_to_messages,
+            last_processed_stream_id,
+        ) = await self._get_device_messages(
+            user_ids=user_ids,
+            from_stream_id=from_stream_id,
+            to_stream_id=to_stream_id,
+        )
+
+        assert (
+            last_processed_stream_id == to_stream_id
+        ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"
+
+        return user_id_device_id_to_messages
+
+    async def get_messages_for_device(
         self,
         user_id: str,
-        device_id: Optional[str],
-        last_stream_id: int,
-        current_stream_id: int,
+        device_id: str,
+        from_stream_id: int,
+        to_stream_id: int,
         limit: int = 100,
-    ) -> Tuple[List[dict], int]:
+    ) -> Tuple[List[JsonDict], int]:
         """
+        Retrieve to-device messages for a single user device.
+
+        Only to-device messages with stream ids between the given boundaries
+        (from < X <= to) are returned.
+
         Args:
-            user_id: The recipient user_id.
-            device_id: The recipient device_id.
-            last_stream_id: The last stream ID checked.
-            current_stream_id: The current position of the to device
-                message stream.
-            limit: The maximum number of messages to retrieve.
+            user_id: The ID of the user to retrieve messages for.
+            device_id: The ID of the device to retrieve to-device messages for.
+            from_stream_id: The lower boundary of stream id to filter with (exclusive).
+            to_stream_id: The upper boundary of stream id to filter with (inclusive).
+            limit: A limit on the number of to-device messages returned.
 
         Returns:
             A tuple containing:
-                * A list of messages for the device.
-                * The max stream token of these messages. There may be more to retrieve
-                  if the given limit was reached.
+                * A list of to-device messages within the given stream id range intended for
+                  the given user / device combo.
+                * The last-processed stream ID. Subsequent calls of this function with the
+                  same device should pass this value as 'from_stream_id'.
         """
-        has_changed = self._device_inbox_stream_cache.has_entity_changed(
-            user_id, last_stream_id
+        (
+            user_id_device_id_to_messages,
+            last_processed_stream_id,
+        ) = await self._get_device_messages(
+            user_ids=[user_id],
+            device_id=device_id,
+            from_stream_id=from_stream_id,
+            to_stream_id=to_stream_id,
+            limit=limit,
         )
-        if not has_changed:
-            return [], current_stream_id
 
-        def get_new_messages_for_device_txn(txn):
-            sql = (
-                "SELECT stream_id, message_json FROM device_inbox"
-                " WHERE user_id = ? AND device_id = ?"
-                " AND ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
-                " LIMIT ?"
+        if not user_id_device_id_to_messages:
+            # There were no messages!
+            return [], to_stream_id
+
+        # Extract the messages, no need to return the user and device ID again
+        to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
+
+        return to_device_messages, last_processed_stream_id
+
+    async def _get_device_messages(
+        self,
+        user_ids: Collection[str],
+        from_stream_id: int,
+        to_stream_id: int,
+        device_id: Optional[str] = None,
+        limit: Optional[int] = None,
+    ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
+        """
+        Retrieve pending to-device messages for a collection of user devices.
+
+        Only to-device messages with stream ids between the given boundaries
+        (from < X <= to) are returned.
+
+        Note that a stream ID can be shared by multiple copies of the same message with
+        different recipient devices. Stream IDs are only unique in the context of a single
+        user ID / device ID pair. Thus, applying a limit (of messages to return) when working
+        with a sliding window of stream IDs is only possible when querying messages of a
+        single user device.
+
+        Finally, note that device IDs are not unique across users.
+
+        Args:
+            user_ids: The user IDs to filter device messages by.
+            from_stream_id: The lower boundary of stream id to filter with (exclusive).
+            to_stream_id: The upper boundary of stream id to filter with (inclusive).
+            device_id: A device ID to query to-device messages for. If not provided, to-device
+                messages from all device IDs for the given user IDs will be queried. May not be
+                provided if `user_ids` contains more than one entry.
+            limit: The maximum number of to-device messages to return. Can only be used when
+                passing a single user ID / device ID tuple.
+
+        Returns:
+            A tuple containing:
+                * A dict of (user_id, device_id) -> list of to-device messages
+                * The last-processed stream ID. If this is less than `to_stream_id`, then
+                    there may be more messages to retrieve. If `limit` is not set, then this
+                    is always equal to 'to_stream_id'.
+        """
+        if not user_ids:
+            logger.warning("No users provided upon querying for device IDs")
+            return {}, to_stream_id
+
+        # Prevent a query for one user's device also retrieving another user's device with
+        # the same device ID (device IDs are not unique across users).
+        if len(user_ids) > 1 and device_id is not None:
+            raise AssertionError(
+                "Programming error: 'device_id' cannot be supplied to "
+                "_get_device_messages when >1 user_id has been provided"
             )
-            txn.execute(
-                sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
+
+        # A limit can only be applied when querying for a single user ID / device ID tuple.
+        # See the docstring of this function for more details.
+        if limit is not None and device_id is None:
+            raise AssertionError(
+                "Programming error: _get_device_messages was passed 'limit' "
+                "without a specific user_id/device_id"
             )
 
-            messages = []
-            stream_pos = current_stream_id
+        user_ids_to_query: Set[str] = set()
+        device_ids_to_query: Set[str] = set()
+
+        # Note that a device ID could be an empty str
+        if device_id is not None:
+            # If a device ID was passed, use it to filter results.
+            # Otherwise, device IDs will be derived from the given collection of user IDs.
+            device_ids_to_query.add(device_id)
+
+        # Determine which users have devices with pending messages
+        for user_id in user_ids:
+            if self._device_inbox_stream_cache.has_entity_changed(
+                user_id, from_stream_id
+            ):
+                # This user has new messages sent to them. Query messages for them
+                user_ids_to_query.add(user_id)
+
+        def get_device_messages_txn(txn: LoggingTransaction):
+            # Build a query to select messages from any of the given devices that
+            # are between the given stream id bounds.
+
+            # If a list of device IDs was not provided, retrieve all devices IDs
+            # for the given users. We explicitly do not query hidden devices, as
+            # hidden devices should not receive to-device messages.
+            # Note that this is more efficient than just dropping `device_id` from the query,
+            # since device_inbox has an index on `(user_id, device_id, stream_id)`
+            if not device_ids_to_query:
+                user_device_dicts = self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="devices",
+                    column="user_id",
+                    iterable=user_ids_to_query,
+                    keyvalues={"user_id": user_id, "hidden": False},
+                    retcols=("device_id",),
+                )
 
-            for row in txn:
-                stream_pos = row[0]
-                messages.append(db_to_json(row[1]))
+                device_ids_to_query.update(
+                    {row["device_id"] for row in user_device_dicts}
+                )
 
-            # If the limit was not reached we know that there's no more data for this
-            # user/device pair up to current_stream_id.
-            if len(messages) < limit:
-                stream_pos = current_stream_id
+            if not device_ids_to_query:
+                # We've ended up with no devices to query.
+                return {}, to_stream_id
 
-            return messages, stream_pos
+            # We include both user IDs and device IDs in this query, as we have an index
+            # (device_inbox_user_stream_id) for them.
+            user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
+                self.database_engine, "user_id", user_ids_to_query
+            )
+            (
+                device_id_many_clause_sql,
+                device_id_many_clause_args,
+            ) = make_in_list_sql_clause(
+                self.database_engine, "device_id", device_ids_to_query
+            )
+
+            sql = f"""
+                SELECT stream_id, user_id, device_id, message_json FROM device_inbox
+                WHERE {user_id_many_clause_sql}
+                AND {device_id_many_clause_sql}
+                AND ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+            """
+            sql_args = (
+                *user_id_many_clause_args,
+                *device_id_many_clause_args,
+                from_stream_id,
+                to_stream_id,
+            )
+
+            # If a limit was provided, limit the data retrieved from the database
+            if limit is not None:
+                sql += "LIMIT ?"
+                sql_args += (limit,)
+
+            txn.execute(sql, sql_args)
+
+            # Create and fill a dictionary of (user ID, device ID) -> list of messages
+            # intended for each device.
+            last_processed_stream_pos = to_stream_id
+            recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
+            for row in txn:
+                last_processed_stream_pos = row[0]
+                recipient_user_id = row[1]
+                recipient_device_id = row[2]
+                message_dict = db_to_json(row[3])
+
+                # Store the device details
+                recipient_device_to_messages.setdefault(
+                    (recipient_user_id, recipient_device_id), []
+                ).append(message_dict)
+
+            if limit is not None and txn.rowcount == limit:
+                # We ended up bumping up against the message limit. There may be more messages
+                # to retrieve. Return what we have, as well as the last stream position that
+                # was processed.
+                #
+                # The caller is expected to set this as the lower (exclusive) bound
+                # for the next query of this device.
+                return recipient_device_to_messages, last_processed_stream_pos
+
+            # The limit was not reached, thus we know that recipient_device_to_messages
+            # contains all to-device messages for the given device and stream id range.
+            #
+            # We return to_stream_id, which the caller should then provide as the lower
+            # (exclusive) bound on the next query of this device.
+            return recipient_device_to_messages, to_stream_id
 
         return await self.db_pool.runInteraction(
-            "get_new_messages_for_device", get_new_messages_for_device_txn
+            "get_device_messages", get_device_messages_txn
         )
 
     @trace