diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 3caf9b31cc..15e94a03cb 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -13,10 +13,11 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Any, Dict, Optional
from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@@ -48,6 +49,9 @@ class DeviceMessageHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.is_mine = hs.is_mine
+ if hs.config.experimental.msc3814_enabled:
+ self.event_sources = hs.get_event_sources()
+ self.device_handler = hs.get_device_handler()
# We only need to poke the federation sender explicitly if its on the
# same instance. Other federation sender instances will get notified by
@@ -303,3 +307,103 @@ class DeviceMessageHandler:
# Enqueue a new federation transaction to send the new
# device messages to each remote destination.
self.federation_sender.send_device_messages(destination)
+
+ async def get_events_for_dehydrated_device(
+ self,
+ requester: Requester,
+ device_id: str,
+ since_token: Optional[str],
+ limit: int,
+ ) -> JsonDict:
+        """Fetches up to `limit` events sent to `device_id` starting from `since_token`
+        and returns the new since token. If there are no more messages, the returned
+        `events` list is empty.
+
+ Args:
+ requester: the user requesting the messages
+ device_id: ID of the dehydrated device
+ since_token: stream id to start from when fetching messages
+ limit: the number of messages to fetch
+ Returns:
+ A dict containing the to-device messages, as well as a token that the client
+ can provide in the next call to fetch the next batch of messages
+ """
+
+ user_id = requester.user.to_string()
+
+ # only allow fetching messages for the dehydrated device id currently associated
+ # with the user
+ dehydrated_device = await self.device_handler.get_dehydrated_device(user_id)
+ if dehydrated_device is None:
+ raise SynapseError(
+ HTTPStatus.FORBIDDEN,
+ "No dehydrated device exists",
+ Codes.FORBIDDEN,
+ )
+
+ dehydrated_device_id, _ = dehydrated_device
+ if device_id != dehydrated_device_id:
+ raise SynapseError(
+ HTTPStatus.FORBIDDEN,
+ "You may only fetch messages for your dehydrated device",
+ Codes.FORBIDDEN,
+ )
+
+ since_stream_id = 0
+ if since_token:
+ if not since_token.startswith("d"):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "from parameter %r has an invalid format" % (since_token,),
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ try:
+ since_stream_id = int(since_token[1:])
+ except Exception:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "from parameter %r has an invalid format" % (since_token,),
+ errcode=Codes.INVALID_PARAM,
+ )
+
+        # Delete any to-device messages the device has already received, i.e. those
+        # up to `since_stream_id` (a no-op when no since token was supplied).
+ deleted = await self.store.delete_messages_for_device(
+ user_id, device_id, since_stream_id
+ )
+ logger.debug(
+ "Deleted %d to-device messages up to %d for user_id %s device_id %s",
+ deleted,
+ since_stream_id,
+ user_id,
+ device_id,
+ )
+
+ to_token = self.event_sources.get_current_token().to_device_key
+
+ messages, stream_id = await self.store.get_messages_for_device(
+ user_id, device_id, since_stream_id, to_token, limit
+ )
+
+ for message in messages:
+ # Remove the message id before sending to client
+ message_id = message.pop("message_id", None)
+ if message_id:
+ set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id)
+
+ logger.debug(
+ "Returning %d to-device messages between %d and %d (current token: %d) for "
+ "dehydrated device %s, user_id %s",
+ len(messages),
+ since_stream_id,
+ stream_id,
+ to_token,
+ device_id,
+ user_id,
+ )
+
+ return {
+ "events": messages,
+ "next_batch": f"d{stream_id}",
+ }