summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorAndrew Morgan <1342360+anoadragon453@users.noreply.github.com>2022-02-01 14:13:38 +0000
committerGitHub <noreply@github.com>2022-02-01 14:13:38 +0000
commit64ec45fc1b0856dc7daacca7d3ab75d50bd89f84 (patch)
tree84b2a6733967600b468a72785da19f43a9b40299 /synapse
parentDon't mention 3.6 EOL under misc (diff)
downloadsynapse-64ec45fc1b0856dc7daacca7d3ab75d50bd89f84.tar.xz
Send to-device messages to application services (#11215)
Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Diffstat (limited to 'synapse')
-rw-r--r--synapse/appservice/__init__.py3
-rw-r--r--synapse/appservice/api.py29
-rw-r--r--synapse/appservice/scheduler.py97
-rw-r--r--synapse/config/experimental.py7
-rw-r--r--synapse/handlers/appservice.py136
-rw-r--r--synapse/handlers/sync.py4
-rw-r--r--synapse/notifier.py4
-rw-r--r--synapse/storage/databases/main/appservice.py24
-rw-r--r--synapse/storage/databases/main/deviceinbox.py276
-rw-r--r--synapse/storage/schema/main/delta/68/02_msc2409_add_device_id_appservice_stream_type.sql21
10 files changed, 507 insertions, 94 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 8c9ff93b2c..7dbebd97b5 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -351,11 +351,13 @@ class AppServiceTransaction:
         id: int,
         events: List[EventBase],
         ephemeral: List[JsonDict],
+        to_device_messages: List[JsonDict],
     ):
         self.service = service
         self.id = id
         self.events = events
         self.ephemeral = ephemeral
+        self.to_device_messages = to_device_messages
 
     async def send(self, as_api: "ApplicationServiceApi") -> bool:
         """Sends this transaction using the provided AS API interface.
@@ -369,6 +371,7 @@ class AppServiceTransaction:
             service=self.service,
             events=self.events,
             ephemeral=self.ephemeral,
+            to_device_messages=self.to_device_messages,
             txn_id=self.id,
         )
 
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index def4424af0..73be7ff3d4 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -218,8 +218,23 @@ class ApplicationServiceApi(SimpleHttpClient):
         service: "ApplicationService",
         events: List[EventBase],
         ephemeral: List[JsonDict],
+        to_device_messages: List[JsonDict],
         txn_id: Optional[int] = None,
     ) -> bool:
+        """
+        Push data to an application service.
+
+        Args:
+            service: The application service to send to.
+            events: The persistent events to send.
+            ephemeral: The ephemeral events to send.
+            to_device_messages: The to-device messages to send.
+            txn_id: An unique ID to assign to this transaction. Application services should
+                deduplicate transactions received with identitical IDs.
+
+        Returns:
+            True if the task succeeded, False if it failed.
+        """
         if service.url is None:
             return True
 
@@ -237,13 +252,15 @@ class ApplicationServiceApi(SimpleHttpClient):
         uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
 
         # Never send ephemeral events to appservices that do not support it
+        body: Dict[str, List[JsonDict]] = {"events": serialized_events}
         if service.supports_ephemeral:
-            body = {
-                "events": serialized_events,
-                "de.sorunome.msc2409.ephemeral": ephemeral,
-            }
-        else:
-            body = {"events": serialized_events}
+            body.update(
+                {
+                    # TODO: Update to stable prefixes once MSC2409 completes FCP merge.
+                    "de.sorunome.msc2409.ephemeral": ephemeral,
+                    "de.sorunome.msc2409.to_device": to_device_messages,
+                }
+            )
 
         try:
             await self.put_json(
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 185e3a5278..c42fa32fff 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -48,7 +48,16 @@ This is all tied together by the AppServiceScheduler which DIs the required
 components.
 """
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Callable,
+    Collection,
+    Dict,
+    List,
+    Optional,
+    Set,
+)
 
 from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.appservice.api import ApplicationServiceApi
@@ -71,6 +80,9 @@ MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
 # Maximum number of ephemeral events to provide in an AS transaction.
 MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
 
+# Maximum number of to-device messages to provide in an AS transaction.
+MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION = 100
+
 
 class ApplicationServiceScheduler:
     """Public facing API for this module. Does the required DI to tie the
@@ -97,15 +109,40 @@ class ApplicationServiceScheduler:
         for service in services:
             self.txn_ctrl.start_recoverer(service)
 
-    def submit_event_for_as(
-        self, service: ApplicationService, event: EventBase
+    def enqueue_for_appservice(
+        self,
+        appservice: ApplicationService,
+        events: Optional[Collection[EventBase]] = None,
+        ephemeral: Optional[Collection[JsonDict]] = None,
+        to_device_messages: Optional[Collection[JsonDict]] = None,
     ) -> None:
-        self.queuer.enqueue_event(service, event)
+        """
+        Enqueue some data to be sent off to an application service.
 
-    def submit_ephemeral_events_for_as(
-        self, service: ApplicationService, events: List[JsonDict]
-    ) -> None:
-        self.queuer.enqueue_ephemeral(service, events)
+        Args:
+            appservice: The application service to create and send a transaction to.
+            events: The persistent room events to send.
+            ephemeral: The ephemeral events to send.
+            to_device_messages: The to-device messages to send. These differ from normal
+                to-device messages sent to clients, as they have 'to_device_id' and
+                'to_user_id' fields.
+        """
+        # We purposefully allow this method to run with empty events/ephemeral
+        # collections, so that callers do not need to check iterable size themselves.
+        if not events and not ephemeral and not to_device_messages:
+            return
+
+        if events:
+            self.queuer.queued_events.setdefault(appservice.id, []).extend(events)
+        if ephemeral:
+            self.queuer.queued_ephemeral.setdefault(appservice.id, []).extend(ephemeral)
+        if to_device_messages:
+            self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
+                to_device_messages
+            )
+
+        # Kick off a new application service transaction
+        self.queuer.start_background_request(appservice)
 
 
 class _ServiceQueuer:
@@ -121,13 +158,15 @@ class _ServiceQueuer:
         self.queued_events: Dict[str, List[EventBase]] = {}
         # dict of {service_id: [events]}
         self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
+        # dict of {service_id: [to_device_message_json]}
+        self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
 
         # the appservices which currently have a transaction in flight
         self.requests_in_flight: Set[str] = set()
         self.txn_ctrl = txn_ctrl
         self.clock = clock
 
-    def _start_background_request(self, service: ApplicationService) -> None:
+    def start_background_request(self, service: ApplicationService) -> None:
         # start a sender for this appservice if we don't already have one
         if service.id in self.requests_in_flight:
             return
@@ -136,16 +175,6 @@ class _ServiceQueuer:
             "as-sender-%s" % (service.id,), self._send_request, service
         )
 
-    def enqueue_event(self, service: ApplicationService, event: EventBase) -> None:
-        self.queued_events.setdefault(service.id, []).append(event)
-        self._start_background_request(service)
-
-    def enqueue_ephemeral(
-        self, service: ApplicationService, events: List[JsonDict]
-    ) -> None:
-        self.queued_ephemeral.setdefault(service.id, []).extend(events)
-        self._start_background_request(service)
-
     async def _send_request(self, service: ApplicationService) -> None:
         # sanity-check: we shouldn't get here if this service already has a sender
         # running.
@@ -162,11 +191,21 @@ class _ServiceQueuer:
                 ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
                 del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
 
-                if not events and not ephemeral:
+                all_to_device_messages = self.queued_to_device_messages.get(
+                    service.id, []
+                )
+                to_device_messages_to_send = all_to_device_messages[
+                    :MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION
+                ]
+                del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]
+
+                if not events and not ephemeral and not to_device_messages_to_send:
                     return
 
                 try:
-                    await self.txn_ctrl.send(service, events, ephemeral)
+                    await self.txn_ctrl.send(
+                        service, events, ephemeral, to_device_messages_to_send
+                    )
                 except Exception:
                     logger.exception("AS request failed")
         finally:
@@ -198,10 +237,24 @@ class _TransactionController:
         service: ApplicationService,
         events: List[EventBase],
         ephemeral: Optional[List[JsonDict]] = None,
+        to_device_messages: Optional[List[JsonDict]] = None,
     ) -> None:
+        """
+        Create a transaction with the given data and send to the provided
+        application service.
+
+        Args:
+            service: The application service to send the transaction to.
+            events: The persistent events to include in the transaction.
+            ephemeral: The ephemeral events to include in the transaction.
+            to_device_messages: The to-device messages to include in the transaction.
+        """
         try:
             txn = await self.store.create_appservice_txn(
-                service=service, events=events, ephemeral=ephemeral or []
+                service=service,
+                events=events,
+                ephemeral=ephemeral or [],
+                to_device_messages=to_device_messages or [],
             )
             service_is_up = await self._is_service_up(service)
             if service_is_up:
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 65c807a19a..e4719d19b8 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -52,3 +52,10 @@ class ExperimentalConfig(Config):
         self.msc3202_device_masquerading_enabled: bool = experimental.get(
             "msc3202_device_masquerading", False
         )
+
+        # MSC2409 (this setting only relates to optionally sending to-device messages).
+        # Presence, typing and read receipt EDUs are already sent to application services that
+        # have opted in to receive them. If enabled, this adds to-device messages to that list.
+        self.msc2409_to_device_messages_enabled: bool = experimental.get(
+            "msc2409_to_device_messages_enabled", False
+        )
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 7833e77e2b..0fb919acf6 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -55,6 +55,9 @@ class ApplicationServicesHandler:
         self.clock = hs.get_clock()
         self.notify_appservices = hs.config.appservice.notify_appservices
         self.event_sources = hs.get_event_sources()
+        self._msc2409_to_device_messages_enabled = (
+            hs.config.experimental.msc2409_to_device_messages_enabled
+        )
 
         self.current_max = 0
         self.is_processing = False
@@ -132,7 +135,9 @@ class ApplicationServicesHandler:
 
                         # Fork off pushes to these services
                         for service in services:
-                            self.scheduler.submit_event_for_as(service, event)
+                            self.scheduler.enqueue_for_appservice(
+                                service, events=[event]
+                            )
 
                         now = self.clock.time_msec()
                         ts = await self.store.get_received_ts(event.event_id)
@@ -199,8 +204,9 @@ class ApplicationServicesHandler:
         Args:
             stream_key: The stream the event came from.
 
-                `stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other
-                value for `stream_key` will cause this function to return early.
+                `stream_key` can be "typing_key", "receipt_key", "presence_key" or
+                "to_device_key". Any other value for `stream_key` will cause this function
+                to return early.
 
                 Ephemeral events will only be pushed to appservices that have opted into
                 receiving them by setting `push_ephemeral` to true in their registration
@@ -216,8 +222,15 @@ class ApplicationServicesHandler:
         if not self.notify_appservices:
             return
 
-        # Ignore any unsupported streams
-        if stream_key not in ("typing_key", "receipt_key", "presence_key"):
+        # Notify appservices of updates in ephemeral event streams.
+        # Only the following streams are currently supported.
+        # FIXME: We should use constants for these values.
+        if stream_key not in (
+            "typing_key",
+            "receipt_key",
+            "presence_key",
+            "to_device_key",
+        ):
             return
 
         # Assert that new_token is an integer (and not a RoomStreamToken).
@@ -233,6 +246,13 @@ class ApplicationServicesHandler:
         # Additional context: https://github.com/matrix-org/synapse/pull/11137
         assert isinstance(new_token, int)
 
+        # Ignore to-device messages if the feature flag is not enabled
+        if (
+            stream_key == "to_device_key"
+            and not self._msc2409_to_device_messages_enabled
+        ):
+            return
+
         # Check whether there are any appservices which have registered to receive
         # ephemeral events.
         #
@@ -266,7 +286,7 @@ class ApplicationServicesHandler:
         with Measure(self.clock, "notify_interested_services_ephemeral"):
             for service in services:
                 if stream_key == "typing_key":
-                    # Note that we don't persist the token (via set_type_stream_id_for_appservice)
+                    # Note that we don't persist the token (via set_appservice_stream_type_pos)
                     # for typing_key due to performance reasons and due to their highly
                     # ephemeral nature.
                     #
@@ -274,7 +294,7 @@ class ApplicationServicesHandler:
                     # and, if they apply to this application service, send it off.
                     events = await self._handle_typing(service, new_token)
                     if events:
-                        self.scheduler.submit_ephemeral_events_for_as(service, events)
+                        self.scheduler.enqueue_for_appservice(service, ephemeral=events)
                     continue
 
                 # Since we read/update the stream position for this AS/stream
@@ -285,28 +305,37 @@ class ApplicationServicesHandler:
                 ):
                     if stream_key == "receipt_key":
                         events = await self._handle_receipts(service, new_token)
-                        if events:
-                            self.scheduler.submit_ephemeral_events_for_as(
-                                service, events
-                            )
+                        self.scheduler.enqueue_for_appservice(service, ephemeral=events)
 
                         # Persist the latest handled stream token for this appservice
-                        await self.store.set_type_stream_id_for_appservice(
+                        await self.store.set_appservice_stream_type_pos(
                             service, "read_receipt", new_token
                         )
 
                     elif stream_key == "presence_key":
                         events = await self._handle_presence(service, users, new_token)
-                        if events:
-                            self.scheduler.submit_ephemeral_events_for_as(
-                                service, events
-                            )
+                        self.scheduler.enqueue_for_appservice(service, ephemeral=events)
 
                         # Persist the latest handled stream token for this appservice
-                        await self.store.set_type_stream_id_for_appservice(
+                        await self.store.set_appservice_stream_type_pos(
                             service, "presence", new_token
                         )
 
+                    elif stream_key == "to_device_key":
+                        # Retrieve a list of to-device message events, as well as the
+                        # maximum stream token of the messages we were able to retrieve.
+                        to_device_messages = await self._get_to_device_messages(
+                            service, new_token, users
+                        )
+                        self.scheduler.enqueue_for_appservice(
+                            service, to_device_messages=to_device_messages
+                        )
+
+                        # Persist the latest handled stream token for this appservice
+                        await self.store.set_appservice_stream_type_pos(
+                            service, "to_device", new_token
+                        )
+
     async def _handle_typing(
         self, service: ApplicationService, new_token: int
     ) -> List[JsonDict]:
@@ -440,6 +469,79 @@ class ApplicationServicesHandler:
 
         return events
 
+    async def _get_to_device_messages(
+        self,
+        service: ApplicationService,
+        new_token: int,
+        users: Collection[Union[str, UserID]],
+    ) -> List[JsonDict]:
+        """
+        Given an application service, determine which events it should receive
+        from those between the last-recorded to-device message stream token for this
+        appservice and the given stream token.
+
+        Args:
+            service: The application service to check for which events it should receive.
+            new_token: The latest to-device event stream token.
+            users: The users to be notified for the new to-device messages
+                (ie, the recipients of the messages).
+
+        Returns:
+            A list of JSON dictionaries containing data derived from the to-device events
+                that should be sent to the given application service.
+        """
+        # Get the stream token that this application service has processed up until
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            service, "to_device"
+        )
+
+        # Filter out users that this appservice is not interested in
+        users_appservice_is_interested_in: List[str] = []
+        for user in users:
+            # FIXME: We should do this farther up the call stack. We currently repeat
+            #  this operation in _handle_presence.
+            if isinstance(user, UserID):
+                user = user.to_string()
+
+            if service.is_interested_in_user(user):
+                users_appservice_is_interested_in.append(user)
+
+        if not users_appservice_is_interested_in:
+            # Return early if the AS was not interested in any of these users
+            return []
+
+        # Retrieve the to-device messages for each user
+        recipient_device_to_messages = await self.store.get_messages_for_user_devices(
+            users_appservice_is_interested_in,
+            from_key,
+            new_token,
+        )
+
+        # According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields
+        # to the event JSON so that the application service will know which user/device
+        # combination this messages was intended for.
+        #
+        # So we mangle this dict into a flat list of to-device messages with the relevant
+        # user ID and device ID embedded inside each message dict.
+        message_payload: List[JsonDict] = []
+        for (
+            user_id,
+            device_id,
+        ), messages in recipient_device_to_messages.items():
+            for message_json in messages:
+                # Remove 'message_id' from the to-device message, as it's an internal ID
+                message_json.pop("message_id", None)
+
+                message_payload.append(
+                    {
+                        "to_user_id": user_id,
+                        "to_device_id": device_id,
+                        **message_json,
+                    }
+                )
+
+        return message_payload
+
     async def query_user_exists(self, user_id: str) -> bool:
         """Check if any application service knows this user_id exists.
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c72ed7c290..aa9a76f8a9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1348,8 +1348,8 @@ class SyncHandler:
         if sync_result_builder.since_token is not None:
             since_stream_id = int(sync_result_builder.since_token.to_device_key)
 
-        if since_stream_id != int(now_token.to_device_key):
-            messages, stream_id = await self.store.get_new_messages_for_device(
+        if device_id is not None and since_stream_id != int(now_token.to_device_key):
+            messages, stream_id = await self.store.get_messages_for_device(
                 user_id, device_id, since_stream_id, now_token.to_device_key
             )
 
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 632b2245ef..5988c67d90 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -461,7 +461,9 @@ class Notifier:
                     users,
                 )
             except Exception:
-                logger.exception("Error notifying application services of event")
+                logger.exception(
+                    "Error notifying application services of ephemeral events"
+                )
 
     def on_new_replication_data(self) -> None:
         """Used to inform replication listeners that something has happened
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 2bb5288431..304814af5d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -198,6 +198,7 @@ class ApplicationServiceTransactionWorkerStore(
         service: ApplicationService,
         events: List[EventBase],
         ephemeral: List[JsonDict],
+        to_device_messages: List[JsonDict],
     ) -> AppServiceTransaction:
         """Atomically creates a new transaction for this application service
         with the given list of events. Ephemeral events are NOT persisted to the
@@ -207,6 +208,7 @@ class ApplicationServiceTransactionWorkerStore(
             service: The service who the transaction is for.
             events: A list of persistent events to put in the transaction.
             ephemeral: A list of ephemeral events to put in the transaction.
+            to_device_messages: A list of to-device messages to put in the transaction.
 
         Returns:
             A new transaction.
@@ -237,7 +239,11 @@ class ApplicationServiceTransactionWorkerStore(
                 (service.id, new_txn_id, event_ids),
             )
             return AppServiceTransaction(
-                service=service, id=new_txn_id, events=events, ephemeral=ephemeral
+                service=service,
+                id=new_txn_id,
+                events=events,
+                ephemeral=ephemeral,
+                to_device_messages=to_device_messages,
             )
 
         return await self.db_pool.runInteraction(
@@ -330,7 +336,11 @@ class ApplicationServiceTransactionWorkerStore(
         events = await self.get_events_as_list(event_ids)
 
         return AppServiceTransaction(
-            service=service, id=entry["txn_id"], events=events, ephemeral=[]
+            service=service,
+            id=entry["txn_id"],
+            events=events,
+            ephemeral=[],
+            to_device_messages=[],
         )
 
     def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@@ -391,7 +401,7 @@ class ApplicationServiceTransactionWorkerStore(
     async def get_type_stream_id_for_appservice(
         self, service: ApplicationService, type: str
     ) -> int:
-        if type not in ("read_receipt", "presence"):
+        if type not in ("read_receipt", "presence", "to_device"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (type,)
@@ -415,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore(
             "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
         )
 
-    async def set_type_stream_id_for_appservice(
+    async def set_appservice_stream_type_pos(
         self, service: ApplicationService, stream_type: str, pos: Optional[int]
     ) -> None:
-        if stream_type not in ("read_receipt", "presence"):
+        if stream_type not in ("read_receipt", "presence", "to_device"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (stream_type,)
             )
 
-        def set_type_stream_id_for_appservice_txn(txn):
+        def set_appservice_stream_type_pos_txn(txn):
             stream_id_type = "%s_stream_id" % stream_type
             txn.execute(
                 "UPDATE application_services_state SET %s = ? WHERE as_id=?"
@@ -433,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore(
             )
 
         await self.db_pool.runInteraction(
-            "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
+            "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
         )
 
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 4eca97189b..8801b7b2dd 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -24,6 +24,7 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    make_in_list_sql_clause,
 )
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
@@ -136,63 +137,260 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
 
-    async def get_new_messages_for_device(
+    async def get_messages_for_user_devices(
+        self,
+        user_ids: Collection[str],
+        from_stream_id: int,
+        to_stream_id: int,
+    ) -> Dict[Tuple[str, str], List[JsonDict]]:
+        """
+        Retrieve to-device messages for a given set of users.
+
+        Only to-device messages with stream ids between the given boundaries
+        (from < X <= to) are returned.
+
+        Args:
+            user_ids: The users to retrieve to-device messages for.
+            from_stream_id: The lower boundary of stream id to filter with (exclusive).
+            to_stream_id: The upper boundary of stream id to filter with (inclusive).
+
+        Returns:
+            A dictionary of (user id, device id) -> list of to-device messages.
+        """
+        # We expect the stream ID returned by _get_device_messages to always
+        # be to_stream_id. So, no need to return it from this function.
+        (
+            user_id_device_id_to_messages,
+            last_processed_stream_id,
+        ) = await self._get_device_messages(
+            user_ids=user_ids,
+            from_stream_id=from_stream_id,
+            to_stream_id=to_stream_id,
+        )
+
+        assert (
+            last_processed_stream_id == to_stream_id
+        ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"
+
+        return user_id_device_id_to_messages
+
+    async def get_messages_for_device(
         self,
         user_id: str,
-        device_id: Optional[str],
-        last_stream_id: int,
-        current_stream_id: int,
+        device_id: str,
+        from_stream_id: int,
+        to_stream_id: int,
         limit: int = 100,
-    ) -> Tuple[List[dict], int]:
+    ) -> Tuple[List[JsonDict], int]:
         """
+        Retrieve to-device messages for a single user device.
+
+        Only to-device messages with stream ids between the given boundaries
+        (from < X <= to) are returned.
+
         Args:
-            user_id: The recipient user_id.
-            device_id: The recipient device_id.
-            last_stream_id: The last stream ID checked.
-            current_stream_id: The current position of the to device
-                message stream.
-            limit: The maximum number of messages to retrieve.
+            user_id: The ID of the user to retrieve messages for.
+            device_id: The ID of the device to retrieve to-device messages for.
+            from_stream_id: The lower boundary of stream id to filter with (exclusive).
+            to_stream_id: The upper boundary of stream id to filter with (inclusive).
+            limit: A limit on the number of to-device messages returned.
 
         Returns:
             A tuple containing:
-                * A list of messages for the device.
-                * The max stream token of these messages. There may be more to retrieve
-                  if the given limit was reached.
+                * A list of to-device messages within the given stream id range intended for
+                  the given user / device combo.
+                * The last-processed stream ID. Subsequent calls of this function with the
+                  same device should pass this value as 'from_stream_id'.
         """
-        has_changed = self._device_inbox_stream_cache.has_entity_changed(
-            user_id, last_stream_id
+        (
+            user_id_device_id_to_messages,
+            last_processed_stream_id,
+        ) = await self._get_device_messages(
+            user_ids=[user_id],
+            device_id=device_id,
+            from_stream_id=from_stream_id,
+            to_stream_id=to_stream_id,
+            limit=limit,
         )
-        if not has_changed:
-            return [], current_stream_id
 
-        def get_new_messages_for_device_txn(txn):
-            sql = (
-                "SELECT stream_id, message_json FROM device_inbox"
-                " WHERE user_id = ? AND device_id = ?"
-                " AND ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
-                " LIMIT ?"
+        if not user_id_device_id_to_messages:
+            # There were no messages!
+            return [], to_stream_id
+
+        # Extract the messages, no need to return the user and device ID again
+        to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
+
+        return to_device_messages, last_processed_stream_id
+
+    async def _get_device_messages(
+        self,
+        user_ids: Collection[str],
+        from_stream_id: int,
+        to_stream_id: int,
+        device_id: Optional[str] = None,
+        limit: Optional[int] = None,
+    ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
+        """
+        Retrieve pending to-device messages for a collection of user devices.
+
+        Only to-device messages with stream ids between the given boundaries
+        (from < X <= to) are returned.
+
+        Note that a stream ID can be shared by multiple copies of the same message with
+        different recipient devices. Stream IDs are only unique in the context of a single
+        user ID / device ID pair. Thus, applying a limit (of messages to return) when working
+        with a sliding window of stream IDs is only possible when querying messages of a
+        single user device.
+
+        Finally, note that device IDs are not unique across users.
+
+        Args:
+            user_ids: The user IDs to filter device messages by.
+            from_stream_id: The lower boundary of stream id to filter with (exclusive).
+            to_stream_id: The upper boundary of stream id to filter with (inclusive).
+            device_id: A device ID to query to-device messages for. If not provided, to-device
+                messages from all device IDs for the given user IDs will be queried. May not be
+                provided if `user_ids` contains more than one entry.
+            limit: The maximum number of to-device messages to return. Can only be used when
+                passing a single user ID / device ID tuple.
+
+        Returns:
+            A tuple containing:
+                * A dict of (user_id, device_id) -> list of to-device messages
+                * The last-processed stream ID. If this is less than `to_stream_id`, then
+                    there may be more messages to retrieve. If `limit` is not set, then this
+                    is always equal to 'to_stream_id'.
+        """
+        if not user_ids:
+            logger.warning("No users provided upon querying for device IDs")
+            return {}, to_stream_id
+
+        # Prevent a query for one user's device also retrieving another user's device with
+        # the same device ID (device IDs are not unique across users).
+        if len(user_ids) > 1 and device_id is not None:
+            raise AssertionError(
+                "Programming error: 'device_id' cannot be supplied to "
+                "_get_device_messages when >1 user_id has been provided"
             )
-            txn.execute(
-                sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
+
+        # A limit can only be applied when querying for a single user ID / device ID tuple.
+        # See the docstring of this function for more details.
+        if limit is not None and device_id is None:
+            raise AssertionError(
+                "Programming error: _get_device_messages was passed 'limit' "
+                "without a specific user_id/device_id"
             )
 
-            messages = []
-            stream_pos = current_stream_id
+        user_ids_to_query: Set[str] = set()
+        device_ids_to_query: Set[str] = set()
+
+        # Note that a device ID could be an empty str
+        if device_id is not None:
+            # If a device ID was passed, use it to filter results.
+            # Otherwise, device IDs will be derived from the given collection of user IDs.
+            device_ids_to_query.add(device_id)
+
+        # Determine which users have devices with pending messages
+        for user_id in user_ids:
+            if self._device_inbox_stream_cache.has_entity_changed(
+                user_id, from_stream_id
+            ):
+                # This user has new messages sent to them. Query messages for them
+                user_ids_to_query.add(user_id)
+
+        def get_device_messages_txn(txn: LoggingTransaction):
+            # Build a query to select messages from any of the given devices that
+            # are between the given stream id bounds.
+
+            # If a list of device IDs was not provided, retrieve all devices IDs
+            # for the given users. We explicitly do not query hidden devices, as
+            # hidden devices should not receive to-device messages.
+            # Note that this is more efficient than just dropping `device_id` from the query,
+            # since device_inbox has an index on `(user_id, device_id, stream_id)`
+            if not device_ids_to_query:
+                user_device_dicts = self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="devices",
+                    column="user_id",
+                    iterable=user_ids_to_query,
+                    keyvalues={"user_id": user_id, "hidden": False},
+                    retcols=("device_id",),
+                )
 
-            for row in txn:
-                stream_pos = row[0]
-                messages.append(db_to_json(row[1]))
+                device_ids_to_query.update(
+                    {row["device_id"] for row in user_device_dicts}
+                )
 
-            # If the limit was not reached we know that there's no more data for this
-            # user/device pair up to current_stream_id.
-            if len(messages) < limit:
-                stream_pos = current_stream_id
+            if not device_ids_to_query:
+                # We've ended up with no devices to query.
+                return {}, to_stream_id
 
-            return messages, stream_pos
+            # We include both user IDs and device IDs in this query, as we have an index
+            # (device_inbox_user_stream_id) for them.
+            user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
+                self.database_engine, "user_id", user_ids_to_query
+            )
+            (
+                device_id_many_clause_sql,
+                device_id_many_clause_args,
+            ) = make_in_list_sql_clause(
+                self.database_engine, "device_id", device_ids_to_query
+            )
+
+            sql = f"""
+                SELECT stream_id, user_id, device_id, message_json FROM device_inbox
+                WHERE {user_id_many_clause_sql}
+                AND {device_id_many_clause_sql}
+                AND ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+            """
+            sql_args = (
+                *user_id_many_clause_args,
+                *device_id_many_clause_args,
+                from_stream_id,
+                to_stream_id,
+            )
+
+            # If a limit was provided, limit the data retrieved from the database
+            if limit is not None:
+                sql += "LIMIT ?"
+                sql_args += (limit,)
+
+            txn.execute(sql, sql_args)
+
+            # Create and fill a dictionary of (user ID, device ID) -> list of messages
+            # intended for each device.
+            last_processed_stream_pos = to_stream_id
+            recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
+            for row in txn:
+                last_processed_stream_pos = row[0]
+                recipient_user_id = row[1]
+                recipient_device_id = row[2]
+                message_dict = db_to_json(row[3])
+
+                # Store the device details
+                recipient_device_to_messages.setdefault(
+                    (recipient_user_id, recipient_device_id), []
+                ).append(message_dict)
+
+            if limit is not None and txn.rowcount == limit:
+                # We ended up bumping up against the message limit. There may be more messages
+                # to retrieve. Return what we have, as well as the last stream position that
+                # was processed.
+                #
+                # The caller is expected to set this as the lower (exclusive) bound
+                # for the next query of this device.
+                return recipient_device_to_messages, last_processed_stream_pos
+
+            # The limit was not reached, thus we know that recipient_device_to_messages
+            # contains all to-device messages for the given device and stream id range.
+            #
+            # We return to_stream_id, which the caller should then provide as the lower
+            # (exclusive) bound on the next query of this device.
+            return recipient_device_to_messages, to_stream_id
 
         return await self.db_pool.runInteraction(
-            "get_new_messages_for_device", get_new_messages_for_device_txn
+            "get_device_messages", get_device_messages_txn
         )
 
     @trace
diff --git a/synapse/storage/schema/main/delta/68/02_msc2409_add_device_id_appservice_stream_type.sql b/synapse/storage/schema/main/delta/68/02_msc2409_add_device_id_appservice_stream_type.sql
new file mode 100644
index 0000000000..bbf0af5311
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/02_msc2409_add_device_id_appservice_stream_type.sql
@@ -0,0 +1,21 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column to track what to_device stream id that this application
+-- service has been caught up to.
+
+-- NULL indicates that this appservice has never received any to_device messages. This
+-- can be used, for example, to avoid sending a huge dump of messages at startup.
+ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT;
\ No newline at end of file