summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2021-10-29 18:34:34 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2021-10-29 18:34:34 +0100
commiteeccdf0e98e6051b3505689c7ce01916de3a5694 (patch)
tree36673d925b3a8daa0371ddf4b90d9dcd06e0385e
parentNote that AS interest via room ID or alias isn't respected (diff)
downloadsynapse-eeccdf0e98e6051b3505689c7ce01916de3a5694.tar.xz
to-device messages are tentatively working
-rw-r--r--synapse/handlers/appservice.py143
-rw-r--r--synapse/handlers/devicemessage.py31
-rw-r--r--synapse/notifier.py4
-rw-r--r--synapse/storage/databases/main/appservice.py8
-rw-r--r--synapse/storage/databases/main/deviceinbox.py122
-rw-r--r--synapse/storage/schema/main/delta/65/02msc2409_add_device_id_appservice_stream_type.sql18
6 files changed, 292 insertions, 34 deletions
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 827884cbb0..c3fe10bfee 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -183,7 +183,7 @@ class ApplicationServicesHandler:
         self,
         stream_key: str,
         new_token: Optional[int],
-        users: Optional[Collection[Union[str, UserID]]] = None,
+        users: Collection[Union[str, UserID]],
     ) -> None:
         """
         This is called by the notifier in the background when
@@ -200,6 +200,8 @@ class ApplicationServicesHandler:
                 Appservices will only receive ephemeral events that fall within their
                 registered user and room namespaces.
 
+                TODO: Update this bit
+
                 Any other value for `stream_key` will cause this function to return early.
 
             new_token: The latest stream token.
@@ -208,21 +210,38 @@ class ApplicationServicesHandler:
         if not self.notify_appservices:
             return
 
-        if stream_key not in ("typing_key", "receipt_key", "presence_key"):
-            return
-
-        services = [
-            service
-            for service in self.store.get_app_services()
-            if service.supports_ephemeral
-        ]
-        if not services:
+        if stream_key in ("typing_key", "receipt_key", "presence_key"):
+            # Check whether there are any appservices which have registered to receive
+            # ephemeral events.
+            #
+            # Note that whether these events are actually relevant to these appservices
+            # is decided later on.
+            services = [
+                service
+                for service in self.store.get_app_services()
+                if service.supports_ephemeral
+            ]
+            if not services:
+                # Bail out early if none of the target appservices have explicitly registered
+                # to receive these ephemeral events.
+                return
+
+        elif stream_key == "to_device_key":
+            # Appservices do not need to register explicit support for receiving
+            # to-device messages.
+            #
+            # Note that whether these events are actually relevant to these appservices is
+            # decided later on.
+            services = self.store.get_app_services()
+
+        else:
+            # This stream_key is not supported.
             return
 
         # We only start a new background process if necessary rather than
         # optimistically (to cut down on overhead).
         self._notify_interested_services_ephemeral(
-            services, stream_key, new_token, users or []
+            services, stream_key, new_token, users
         )
 
     @wrap_as_background_process("notify_interested_services_ephemeral")
@@ -233,7 +252,7 @@ class ApplicationServicesHandler:
         new_token: Optional[int],
         users: Collection[Union[str, UserID]],
     ) -> None:
-        logger.debug("Checking interested services for %s" % (stream_key))
+        logger.debug("Checking interested services for %s" % stream_key)
         with Measure(self.clock, "notify_interested_services_ephemeral"):
             for service in services:
                 # Only handle typing if we have the latest token
@@ -256,6 +275,8 @@ class ApplicationServicesHandler:
                     # Persist the latest handled stream token for this appservice
                     # TODO: We seem to update the stream token for each appservice,
                     #  even if sending the ephemeral events to the appservice failed.
+                    # This is expected for typing, receipt and presence, but will need
+                    # to be handled for device* streams.
                     await self.store.set_type_stream_id_for_appservice(
                         service, "read_receipt", new_token
                     )
@@ -270,6 +291,104 @@ class ApplicationServicesHandler:
                         service, "presence", new_token
                     )
 
+                elif stream_key == "to_device_key" and new_token is not None:
+                    events = await self._handle_to_device(service, new_token, users)
+                    if events:
+                        self.scheduler.submit_ephemeral_events_for_as(service, events)
+
+                    # Persist the latest handled stream token for this appservice
+                    await self.store.set_type_stream_id_for_appservice(
+                        service, "to_device", new_token
+                    )
+
+    async def _handle_to_device(
+        self,
+        service: ApplicationService,
+        new_token: int,
+        users: Collection[Union[str, UserID]],
+    ) -> List[JsonDict]:
+        """
+        Given an application service, determine which to-device messages it should
+        receive from those between the last-recorded to-device stream token for this
+        appservice and the given stream token.
+
+        Args:
+            service: The application service to check for which events it should receive.
+            new_token: The latest to-device event stream token.
+            users: The users that should receive new to-device messages.
+
+        Returns:
+            A list of JSON dictionaries containing data derived from the to-device events
+            that should be sent to the given application service.
+        """
+        # Get the stream token that this application service has processed up until
+        # TODO: Is 'users' always going to be one user here? Sometimes it's the sender!
+
+        # TODO: DB migration to add a column for to_device to application_services_state table
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            service, "to_device"
+        )
+
+        # Filter out users that this appservice is not interested in
+        users_appservice_is_interested_in: List[str] = []
+        for user in users:
+            if isinstance(user, UserID):
+                user = user.to_string()
+
+            if service.is_interested_in_user(user):
+                users_appservice_is_interested_in.append(user)
+
+        if not users_appservice_is_interested_in:
+            # Return early if the AS was not interested in any of these users
+            return []
+
+        # Retrieve the to-device messages for each user
+        (
+            recipient_user_id_device_id_to_messages,
+            max_stream_token,
+        ) = await self.store.get_new_messages(
+            users_appservice_is_interested_in, from_key, new_token, limit=100
+        )
+
+        logger.info(
+            "*** Users: %s, from: %s, to: %s",
+            users_appservice_is_interested_in,
+            from_key,
+            new_token,
+        )
+        logger.info(
+            "*** Got to-device message: %s", recipient_user_id_device_id_to_messages
+        )
+
+        # TODO: Keep pulling out if max_stream_token != new_token?
+
+        # According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields
+        # to the event JSON so that the application service will know which user/device
+        # combination this message was intended for.
+        #
+        # So we mangle this dict into a flat list of to-device messages with the relevant
+        # user ID and device ID embedded inside each message dict.
+        message_payload: List[JsonDict] = []
+        for (
+            user_id,
+            device_id,
+        ), messages in recipient_user_id_device_id_to_messages.items():
+            for message_json in messages:
+                # Remove 'message_id' from the to-device message, as it's an internal ID
+                message_json.pop("message_id", None)
+
+                message_payload.append(
+                    {
+                        "to_user_id": user_id,
+                        "to_device_id": device_id,
+                        **message_json,
+                    }
+                )
+
+        logger.info("*** Ended up with messages: %s", message_payload)
+
+        return message_payload
+
     async def _handle_typing(
         self, service: ApplicationService, new_token: int
     ) -> List[JsonDict]:
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index b6a2a34ab7..bf03f1b557 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -89,6 +89,13 @@ class DeviceMessageHandler:
         )
 
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
+        """
+        Handle receiving to-device messages from remote homeservers.
+
+        Args:
+            origin: The remote homeserver.
+            content: The JSON dictionary containing the to-device messages.
+        """
         local_messages = {}
         sender_user_id = content["sender"]
         if origin != get_domain_from_id(sender_user_id):
@@ -135,12 +142,16 @@ class DeviceMessageHandler:
                 message_type, sender_user_id, by_device
             )
 
-        stream_id = await self.store.add_messages_from_remote_to_device_inbox(
+        # Add messages to the database.
+        # Retrieve the stream token of the last-processed to-device message.
+        max_stream_token = await self.store.add_messages_from_remote_to_device_inbox(
             origin, message_id, local_messages
         )
 
+        # Notify listeners that there are new to-device messages to process,
+        # handing them the latest stream token.
         self.notifier.on_new_event(
-            "to_device_key", stream_id, users=local_messages.keys()
+            "to_device_key", max_stream_token, users=local_messages.keys()
         )
 
     async def _check_for_unknown_devices(
@@ -195,6 +206,14 @@ class DeviceMessageHandler:
         message_type: str,
         messages: Dict[str, Dict[str, JsonDict]],
     ) -> None:
+        """
+        Handle a request from a user to send to-device message(s).
+
+        Args:
+            requester: The user that is sending the to-device messages.
+            message_type: The type of to-device messages that are being sent.
+            messages: A dictionary containing recipients mapped to messages intended for them.
+        """
         sender_user_id = requester.user.to_string()
 
         message_id = random_string(16)
@@ -257,12 +276,16 @@ class DeviceMessageHandler:
                 "org.matrix.opentracing_context": json_encoder.encode(context),
             }
 
-        stream_id = await self.store.add_messages_to_device_inbox(
+        # Add messages to the database.
+        # Retrieve the stream token of the last-processed to-device message.
+        max_stream_token = await self.store.add_messages_to_device_inbox(
             local_messages, remote_edu_contents
         )
 
+        # Notify listeners that there are new to-device messages to process,
+        # handing them the latest stream token.
         self.notifier.on_new_event(
-            "to_device_key", stream_id, users=local_messages.keys()
+            "to_device_key", max_stream_token, users=local_messages.keys()
         )
 
         if self.federation_sender:
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 1acd899fab..d9dc50e146 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -378,7 +378,7 @@ class Notifier:
         self,
         stream_key: str,
         new_token: Union[int, RoomStreamToken],
-        users: Optional[Collection[Union[str, UserID]]] = None,
+        users: Collection[Union[str, UserID]],
     ) -> None:
         """Notify application services of ephemeral event activity.
 
@@ -392,7 +392,7 @@ class Notifier:
             if isinstance(new_token, int):
                 stream_token = new_token
             self.appservice_handler.notify_interested_services_ephemeral(
-                stream_key, stream_token, users or []
+                stream_key, stream_token, users
             )
         except Exception:
             logger.exception("Error notifying application services of event")
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 2da2659f41..63dbfd11fc 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -387,7 +387,7 @@ class ApplicationServiceTransactionWorkerStore(
     async def get_type_stream_id_for_appservice(
         self, service: ApplicationService, type: str
     ) -> int:
-        if type not in ("read_receipt", "presence"):
+        if type not in ("read_receipt", "presence", "to_device"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (type,)
@@ -412,16 +412,16 @@ class ApplicationServiceTransactionWorkerStore(
         )
 
     async def set_type_stream_id_for_appservice(
-        self, service: ApplicationService, type: str, pos: Optional[int]
+        self, service: ApplicationService, stream_type: str, pos: Optional[int]
     ) -> None:
-        if type not in ("read_receipt", "presence"):
+        if stream_type not in ("read_receipt", "presence", "to_device"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (type,)
             )
 
         def set_type_stream_id_for_appservice_txn(txn):
-            stream_id_type = "%s_stream_id" % type
+            stream_id_type = "%s_stream_id" % stream_type
             txn.execute(
                 "UPDATE application_services_state SET %s = ? WHERE as_id=?"
                 % stream_id_type,
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 3154906d45..d20cc1816e 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -13,15 +13,16 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Tuple
+from typing import Collection, Dict, List, Optional, Tuple
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, make_in_list_sql_clause
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -112,31 +113,125 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
 
+    async def get_new_messages(
+        self,
+        user_ids: Collection[str],
+        from_stream_token: int,
+        to_stream_token: int,
+        limit: int = 100,
+    ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
+        """
+        Retrieve to-device messages for a given set of user IDs.
+
+        Only to-device messages with stream tokens between the given boundaries
+        (from < X <= to) are returned.
+
+        Note that multiple messages can have the same stream token. Stream tokens are
+        unique to *messages*, but there might be multiple recipients of a message, and
+        thus multiple entries in the device_inbox table with the same stream token.
+
+        Args:
+            user_ids: The users to retrieve to-device messages for.
+            from_stream_token: The lower boundary of stream token to filter with (exclusive).
+            to_stream_token: The upper boundary of stream token to filter with (inclusive).
+            limit: The maximum number of to-device messages to return.
+
+        Returns:
+            A tuple containing the following:
+                * A dict mapping (user ID, device ID) tuples to the list of to-device
+                  messages destined for that device
+                * The stream position that processing reached; pass this as the lower
+                  bound on the next call
+        """
+        # Bail out if none of these users have any messages
+        for user_id in user_ids:
+            if self._device_inbox_stream_cache.has_entity_changed(
+                user_id, from_stream_token
+            ):
+                break
+        else:
+            logger.info("*** Bailing out")
+            return {}, to_stream_token
+
+        def get_new_messages_txn(txn):
+            # Build a query to select messages from any of the given users that are between
+            # the given stream token bounds
+            sql = "SELECT stream_id, user_id, device_id, message_json FROM device_inbox"
+
+            # Scope to only the given users. We need to use this method as doing so is
+            # different across database engines.
+            many_clause_sql, many_clause_args = make_in_list_sql_clause(
+                self.database_engine, "user_id", user_ids
+            )
+
+            sql += (
+                " WHERE %s"
+                " AND ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+                " LIMIT ?"
+            ) % many_clause_sql
+
+            logger.info("*** %s\n\n%s", many_clause_sql, many_clause_args)
+            logger.info("*** %s", sql)
+            txn.execute(
+                sql, (*many_clause_args, from_stream_token, to_stream_token, limit)
+            )
+
+            # Create a dictionary of (user ID, device ID) -> list of messages that
+            # that device is meant to receive.
+            recipient_user_id_device_id_to_messages = {}
+
+            stream_pos = to_stream_token
+            total_messages_processed = 0
+            for row in txn:
+                # Record the last-processed stream position, to return later.
+                # Note that we process messages here in order of ascending stream token.
+                stream_pos = row[0]
+                recipient_user_id = row[1]
+                recipient_device_id = row[2]
+                message_dict = db_to_json(row[3])
+
+                recipient_user_id_device_id_to_messages.setdefault(
+                    (recipient_user_id, recipient_device_id), []
+                ).append(message_dict)
+                total_messages_processed += 1
+
+            # This is needed (REVIEW: I think) as you can have multiple rows for a
+            # single to-device message (due to multiple recipients).
+            if total_messages_processed < limit:
+                stream_pos = to_stream_token
+
+            return recipient_user_id_device_id_to_messages, stream_pos
+
+        return await self.db_pool.runInteraction(
+            "get_new_messages", get_new_messages_txn
+        )
+
     async def get_new_messages_for_device(
         self,
         user_id: str,
         device_id: Optional[str],
-        last_stream_id: int,
-        current_stream_id: int,
+        last_stream_token: int,
+        current_stream_token: int,
         limit: int = 100,
     ) -> Tuple[List[dict], int]:
         """
         Args:
             user_id: The recipient user_id.
             device_id: The recipient device_id.
-            last_stream_id: The last stream ID checked.
-            current_stream_id: The current position of the to device
+            last_stream_token: The last stream ID checked.
+            current_stream_token: The current position of the to device
                 message stream.
             limit: The maximum number of messages to retrieve.
 
         Returns:
-            A list of messages for the device and where in the stream the messages got to.
+            A tuple containing:
+                * A list of messages for the device
+                * The max stream token of these messages. There may be more to retrieve
+                  if the given limit was reached.
         """
         has_changed = self._device_inbox_stream_cache.has_entity_changed(
-            user_id, last_stream_id
+            user_id, last_stream_token
         )
         if not has_changed:
-            return [], current_stream_id
+            return [], current_stream_token
 
         def get_new_messages_for_device_txn(txn):
             sql = (
@@ -147,14 +242,17 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(
-                sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
+                sql,
+                (user_id, device_id, last_stream_token, current_stream_token, limit),
             )
             messages = []
+            stream_pos = current_stream_token
+
             for row in txn:
                 stream_pos = row[0]
                 messages.append(db_to_json(row[1]))
             if len(messages) < limit:
-                stream_pos = current_stream_id
+                stream_pos = current_stream_token
             return messages, stream_pos
 
         return await self.db_pool.runInteraction(
@@ -369,7 +467,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
         Args:
             local_messages_by_user_and_device:
-                Dictionary of user_id to device_id to message.
+                Dictionary of recipient user_id to recipient device_id to message.
             remote_messages_by_destination:
                 Dictionary of destination server_name to the EDU JSON to send.
 
diff --git a/synapse/storage/schema/main/delta/65/02msc2409_add_device_id_appservice_stream_type.sql b/synapse/storage/schema/main/delta/65/02msc2409_add_device_id_appservice_stream_type.sql
new file mode 100644
index 0000000000..ba88ca03c2
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/02msc2409_add_device_id_appservice_stream_type.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column to track what to_device stream token that this application
+-- service has been caught up to.
+ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT;
\ No newline at end of file