summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/11247.misc1
-rw-r--r--synapse/handlers/appservice.py24
-rw-r--r--synapse/handlers/devicemessage.py31
-rw-r--r--synapse/storage/databases/main/appservice.py8
-rw-r--r--synapse/storage/databases/main/deviceinbox.py23
-rw-r--r--tests/handlers/test_appservice.py8
6 files changed, 78 insertions, 17 deletions
diff --git a/changelog.d/11247.misc b/changelog.d/11247.misc
new file mode 100644
index 0000000000..5ce701560e
--- /dev/null
+++ b/changelog.d/11247.misc
@@ -0,0 +1 @@
+Clean up code relating to to-device messages and sending ephemeral events to application services.
\ No newline at end of file
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index ddc9105ee9..9abdad262b 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -188,7 +188,7 @@ class ApplicationServicesHandler:
         self,
         stream_key: str,
         new_token: Union[int, RoomStreamToken],
-        users: Optional[Collection[Union[str, UserID]]] = None,
+        users: Collection[Union[str, UserID]],
     ) -> None:
         """
         This is called by the notifier in the background when an ephemeral event is handled
@@ -203,7 +203,9 @@ class ApplicationServicesHandler:
                 value for `stream_key` will cause this function to return early.
 
                 Ephemeral events will only be pushed to appservices that have opted into
-                them.
+                receiving them by setting `push_ephemeral` to true in their registration
+                file. Note that while MSC2409 is experimental, this option is called
+                `de.sorunome.msc2409.push_ephemeral`.
 
                 Appservices will only receive ephemeral events that fall within their
                 registered user and room namespaces.
@@ -214,6 +216,7 @@ class ApplicationServicesHandler:
         if not self.notify_appservices:
             return
 
+        # Ignore any unsupported streams
         if stream_key not in ("typing_key", "receipt_key", "presence_key"):
             return
 
@@ -230,18 +233,25 @@ class ApplicationServicesHandler:
         # Additional context: https://github.com/matrix-org/synapse/pull/11137
         assert isinstance(new_token, int)
 
+        # Check whether there are any appservices which have registered to receive
+        # ephemeral events.
+        #
+        # Note that whether these events are actually relevant to these appservices
+        # is decided later on.
         services = [
             service
             for service in self.store.get_app_services()
             if service.supports_ephemeral
         ]
         if not services:
+            # Bail out early if none of the target appservices have explicitly registered
+            # to receive these ephemeral events.
             return
 
         # We only start a new background process if necessary rather than
         # optimistically (to cut down on overhead).
         self._notify_interested_services_ephemeral(
-            services, stream_key, new_token, users or []
+            services, stream_key, new_token, users
         )
 
     @wrap_as_background_process("notify_interested_services_ephemeral")
@@ -252,7 +262,7 @@ class ApplicationServicesHandler:
         new_token: int,
         users: Collection[Union[str, UserID]],
     ) -> None:
-        logger.debug("Checking interested services for %s" % (stream_key))
+        logger.debug("Checking interested services for %s", stream_key)
         with Measure(self.clock, "notify_interested_services_ephemeral"):
             for service in services:
                 if stream_key == "typing_key":
@@ -345,6 +355,9 @@ class ApplicationServicesHandler:
 
         Args:
             service: The application service to check for which events it should receive.
+            new_token: A receipts event stream token. Purely used to double-check that the
+                from_token we pull from the database isn't greater than or equal to this
+                token. Prevents accidentally duplicating work.
 
         Returns:
             A list of JSON dictionaries containing data derived from the read receipts that
@@ -382,6 +395,9 @@ class ApplicationServicesHandler:
         Args:
             service: The application service that ephemeral events are being sent to.
             users: The users that should receive the presence update.
+            new_token: A presence update stream token. Purely used to double-check that the
+                from_token we pull from the database isn't greater than or equal to this
+                token. Prevents accidentally duplicating work.
 
         Returns:
             A list of json dictionaries containing data derived from the presence events
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index b6a2a34ab7..b582266af9 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -89,6 +89,13 @@ class DeviceMessageHandler:
         )
 
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
+        """
+        Handle receiving to-device messages from remote homeservers.
+
+        Args:
+            origin: The remote homeserver.
+            content: The JSON dictionary containing the to-device messages.
+        """
         local_messages = {}
         sender_user_id = content["sender"]
         if origin != get_domain_from_id(sender_user_id):
@@ -135,12 +142,16 @@ class DeviceMessageHandler:
                 message_type, sender_user_id, by_device
             )
 
-        stream_id = await self.store.add_messages_from_remote_to_device_inbox(
+        # Add messages to the database.
+        # Retrieve the stream id of the last-processed to-device message.
+        last_stream_id = await self.store.add_messages_from_remote_to_device_inbox(
             origin, message_id, local_messages
         )
 
+        # Notify listeners that there are new to-device messages to process,
+        # handing them the latest stream id.
         self.notifier.on_new_event(
-            "to_device_key", stream_id, users=local_messages.keys()
+            "to_device_key", last_stream_id, users=local_messages.keys()
         )
 
     async def _check_for_unknown_devices(
@@ -195,6 +206,14 @@ class DeviceMessageHandler:
         message_type: str,
         messages: Dict[str, Dict[str, JsonDict]],
     ) -> None:
+        """
+        Handle a request from a user to send to-device message(s).
+
+        Args:
+            requester: The user that is sending the to-device messages.
+            message_type: The type of to-device messages that are being sent.
+            messages: A dictionary containing recipients mapped to messages intended for them.
+        """
         sender_user_id = requester.user.to_string()
 
         message_id = random_string(16)
@@ -257,12 +276,16 @@ class DeviceMessageHandler:
                 "org.matrix.opentracing_context": json_encoder.encode(context),
             }
 
-        stream_id = await self.store.add_messages_to_device_inbox(
+        # Add messages to the database.
+        # Retrieve the stream id of the last-processed to-device message.
+        last_stream_id = await self.store.add_messages_to_device_inbox(
             local_messages, remote_edu_contents
         )
 
+        # Notify listeners that there are new to-device messages to process,
+        # handing them the latest stream id.
         self.notifier.on_new_event(
-            "to_device_key", stream_id, users=local_messages.keys()
+            "to_device_key", last_stream_id, users=local_messages.keys()
         )
 
         if self.federation_sender:
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 2da2659f41..baec35ee27 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -412,16 +412,16 @@ class ApplicationServiceTransactionWorkerStore(
         )
 
     async def set_type_stream_id_for_appservice(
-        self, service: ApplicationService, type: str, pos: Optional[int]
+        self, service: ApplicationService, stream_type: str, pos: Optional[int]
     ) -> None:
-        if type not in ("read_receipt", "presence"):
+        if stream_type not in ("read_receipt", "presence"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
-                % (type,)
+                % (stream_type,)
             )
 
         def set_type_stream_id_for_appservice_txn(txn):
-            stream_id_type = "%s_stream_id" % type
+            stream_id_type = "%s_stream_id" % stream_type
             txn.execute(
                 "UPDATE application_services_state SET %s = ? WHERE as_id=?"
                 % stream_id_type,
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 264e625bd7..ae3afdd5d2 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -134,7 +134,10 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             limit: The maximum number of messages to retrieve.
 
         Returns:
-            A list of messages for the device and where in the stream the messages got to.
+            A tuple containing:
+                * A list of messages for the device.
+                * The max stream token of these messages. There may be more to retrieve
+                  if the given limit was reached.
         """
         has_changed = self._device_inbox_stream_cache.has_entity_changed(
             user_id, last_stream_id
@@ -153,12 +156,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             txn.execute(
                 sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
             )
+
             messages = []
+            stream_pos = current_stream_id
+
             for row in txn:
                 stream_pos = row[0]
                 messages.append(db_to_json(row[1]))
+
+            # If the limit was not reached we know that there's no more data for this
+            # user/device pair up to current_stream_id.
             if len(messages) < limit:
                 stream_pos = current_stream_id
+
             return messages, stream_pos
 
         return await self.db_pool.runInteraction(
@@ -260,13 +270,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))
+
             messages = []
+            stream_pos = current_stream_id
+
             for row in txn:
                 stream_pos = row[0]
                 messages.append(db_to_json(row[1]))
+
+            # If the limit was not reached we know that there's no more data for this
+            # user/device pair up to current_stream_id.
             if len(messages) < limit:
                 log_kv({"message": "Set stream position to current position"})
                 stream_pos = current_stream_id
+
             return messages, stream_pos
 
         return await self.db_pool.runInteraction(
@@ -372,8 +389,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         """Used to send messages from this server.
 
         Args:
-            local_messages_by_user_and_device:
-                Dictionary of user_id to device_id to message.
+            local_messages_by_user_then_device:
+                Dictionary of recipient user_id to recipient device_id to message.
             remote_messages_by_destination:
                 Dictionary of destination server_name to the EDU JSON to send.
 
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 1f6a924452..d6f14e2dba 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -272,7 +272,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             make_awaitable(([event], None))
         )
 
-        self.handler.notify_interested_services_ephemeral("receipt_key", 580)
+        self.handler.notify_interested_services_ephemeral(
+            "receipt_key", 580, ["@fakerecipient:example.com"]
+        )
         self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with(
             interested_service, [event]
         )
@@ -300,7 +302,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             make_awaitable(([event], None))
         )
 
-        self.handler.notify_interested_services_ephemeral("receipt_key", 579)
+        self.handler.notify_interested_services_ephemeral(
+            "receipt_key", 580, ["@fakerecipient:example.com"]
+        )
         self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called()
 
     def _mkservice(self, is_interested, protocols=None):