summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/device.py4
-rw-r--r--synapse/handlers/devicemessage.py108
2 files changed, 109 insertions, 3 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d73d9dca08..f3a713f5fa 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -653,6 +653,7 @@ class DeviceHandler(DeviceWorkerHandler):
     async def store_dehydrated_device(
         self,
         user_id: str,
+        device_id: Optional[str],
         device_data: JsonDict,
         initial_device_display_name: Optional[str] = None,
     ) -> str:
@@ -661,6 +662,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
         Args:
             user_id: the user that we are storing the device for
+            device_id: device id supplied by client
             device_data: the dehydrated device information
             initial_device_display_name: The display name to use for the device
         Returns:
@@ -668,7 +670,7 @@ class DeviceHandler(DeviceWorkerHandler):
         """
         device_id = await self.check_device_registered(
             user_id,
-            None,
+            device_id,
             initial_device_display_name,
         )
         old_device_id = await self.store.store_dehydrated_device(
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 3caf9b31cc..15e94a03cb 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
@@ -48,6 +49,9 @@ class DeviceMessageHandler:
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.is_mine = hs.is_mine
+        if hs.config.experimental.msc3814_enabled:
+            self.event_sources = hs.get_event_sources()
+            self.device_handler = hs.get_device_handler()
 
         # We only need to poke the federation sender explicitly if its on the
         # same instance. Other federation sender instances will get notified by
@@ -303,3 +307,103 @@ class DeviceMessageHandler:
                 # Enqueue a new federation transaction to send the new
                 # device messages to each remote destination.
                 self.federation_sender.send_device_messages(destination)
+
+    async def get_events_for_dehydrated_device(
+        self,
+        requester: Requester,
+        device_id: str,
+        since_token: Optional[str],
+        limit: int,
+    ) -> JsonDict:
+        """Fetches up to `limit` events sent to `device_id` starting from `since_token`
+        and returns the new since token. If there are no more messages, returns an empty
+        array.
+
+        Args:
+            requester: the user requesting the messages
+            device_id: ID of the dehydrated device
+            since_token: stream id to start from when fetching messages
+            limit: the number of messages to fetch
+        Returns:
+            A dict containing the to-device messages, as well as a token that the client
+            can provide in the next call to fetch the next batch of messages
+        """
+
+        user_id = requester.user.to_string()
+
+        # only allow fetching messages for the dehydrated device id currently associated
+        # with the user
+        dehydrated_device = await self.device_handler.get_dehydrated_device(user_id)
+        if dehydrated_device is None:
+            raise SynapseError(
+                HTTPStatus.FORBIDDEN,
+                "No dehydrated device exists",
+                Codes.FORBIDDEN,
+            )
+
+        dehydrated_device_id, _ = dehydrated_device
+        if device_id != dehydrated_device_id:
+            raise SynapseError(
+                HTTPStatus.FORBIDDEN,
+                "You may only fetch messages for your dehydrated device",
+                Codes.FORBIDDEN,
+            )
+
+        since_stream_id = 0
+        if since_token:
+            if not since_token.startswith("d"):
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "from parameter %r has an invalid format" % (since_token,),
+                    errcode=Codes.INVALID_PARAM,
+                )
+
+            try:
+                since_stream_id = int(since_token[1:])
+            except Exception:
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "from parameter %r has an invalid format" % (since_token,),
+                    errcode=Codes.INVALID_PARAM,
+                )
+
+            # if we have a since token, delete any to-device messages before that token
+            # (since we now know that the device has received them)
+            deleted = await self.store.delete_messages_for_device(
+                user_id, device_id, since_stream_id
+            )
+            logger.debug(
+                "Deleted %d to-device messages up to %d for user_id %s device_id %s",
+                deleted,
+                since_stream_id,
+                user_id,
+                device_id,
+            )
+
+        to_token = self.event_sources.get_current_token().to_device_key
+
+        messages, stream_id = await self.store.get_messages_for_device(
+            user_id, device_id, since_stream_id, to_token, limit
+        )
+
+        for message in messages:
+            # Remove the message id before sending to client
+            message_id = message.pop("message_id", None)
+            if message_id:
+                set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id)
+
+        logger.debug(
+            "Returning %d to-device messages between %d and %d (current token: %d) for "
+            "dehydrated device %s, user_id %s",
+            len(messages),
+            since_stream_id,
+            stream_id,
+            to_token,
+            device_id,
+            user_id,
+        )
+
+        return {
+            "events": messages,
+            "next_batch": f"d{stream_id}",
+        }