summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/sliding_sync.py103
-rw-r--r--synapse/rest/client/sync.py17
-rw-r--r--synapse/types/handlers/__init__.py35
-rw-r--r--synapse/types/rest/client/__init__.py48
4 files changed, 193 insertions, 10 deletions
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index bb81ca9d97..818b13621c 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -542,11 +542,15 @@ class SlidingSyncHandler:
 
             rooms[room_id] = room_sync_result
 
+        extensions = await self.get_extensions_response(
+            sync_config=sync_config, to_token=to_token
+        )
+
         return SlidingSyncResult(
             next_pos=to_token,
             lists=lists,
             rooms=rooms,
-            extensions={},
+            extensions=extensions,
         )
 
     async def get_sync_room_ids_for_user(
@@ -1445,3 +1449,100 @@ class SlidingSyncHandler:
             notification_count=0,
             highlight_count=0,
         )
+
+    async def get_extensions_response(
+        self,
+        sync_config: SlidingSyncConfig,
+        to_token: StreamToken,
+    ) -> SlidingSyncResult.Extensions:
+        """Handle extension requests.
+
+        Args:
+            sync_config: Sync configuration
+            to_token: The point in the stream to sync up to.
+        """
+
+        if sync_config.extensions is None:
+            return SlidingSyncResult.Extensions()
+
+        to_device_response = None
+        if sync_config.extensions.to_device:
+            to_device_response = await self.get_to_device_extensions_response(
+                sync_config=sync_config,
+                to_device_request=sync_config.extensions.to_device,
+                to_token=to_token,
+            )
+
+        return SlidingSyncResult.Extensions(to_device=to_device_response)
+
+    async def get_to_device_extensions_response(
+        self,
+        sync_config: SlidingSyncConfig,
+        to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
+        to_token: StreamToken,
+    ) -> SlidingSyncResult.Extensions.ToDeviceExtension:
+        """Handle to-device extension (MSC3885)
+
+        Args:
+            sync_config: Sync configuration
+            to_device_request: The to-device extension from the request
+            to_token: The point in the stream to sync up to.
+        """
+
+        user_id = sync_config.user.to_string()
+        device_id = sync_config.device_id
+
+        # Check that this request has a valid device ID (not all requests have
+        # to belong to a device, and so device_id is None), and that the
+        # extension is enabled.
+        if device_id is None or not to_device_request.enabled:
+            return SlidingSyncResult.Extensions.ToDeviceExtension(
+                next_batch=f"{to_token.to_device_key}",
+                events=[],
+            )
+
+        since_stream_id = 0
+        if to_device_request.since is not None:
+            # We've already validated this is an int.
+            since_stream_id = int(to_device_request.since)
+
+            if to_token.to_device_key < since_stream_id:
+                # The since token is ahead of our current token, so we return an
+                # empty response.
+                logger.warning(
+                    "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
+                    since_stream_id,
+                    to_token.to_device_key,
+                )
+                return SlidingSyncResult.Extensions.ToDeviceExtension(
+                    next_batch=to_device_request.since,
+                    events=[],
+                )
+
+            # Delete everything before the given since token, as we know the
+            # device must have received them.
+            deleted = await self.store.delete_messages_for_device(
+                user_id=user_id,
+                device_id=device_id,
+                up_to_stream_id=since_stream_id,
+            )
+
+            logger.debug(
+                "Deleted %d to-device messages up to %d for %s",
+                deleted,
+                since_stream_id,
+                user_id,
+            )
+
+        messages, stream_id = await self.store.get_messages_for_device(
+            user_id=user_id,
+            device_id=device_id,
+            from_stream_id=since_stream_id,
+            to_stream_id=to_token.to_device_key,
+            limit=min(to_device_request.limit, 100),  # Limit to at most 100 events
+        )
+
+        return SlidingSyncResult.Extensions.ToDeviceExtension(
+            next_batch=f"{stream_id}",
+            events=messages,
+        )
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 13aed1dc85..94d5faf9f7 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet):
         response["rooms"] = await self.encode_rooms(
             requester, sliding_sync_result.rooms
         )
-        response["extensions"] = {}  # TODO: sliding_sync_result.extensions
+        response["extensions"] = await self.encode_extensions(
+            requester, sliding_sync_result.extensions
+        )
 
         return response
 
@@ -1054,6 +1056,19 @@ class SlidingSyncRestServlet(RestServlet):
 
         return serialized_rooms
 
+    async def encode_extensions(
+        self, requester: Requester, extensions: SlidingSyncResult.Extensions
+    ) -> JsonDict:
+        result = {}
+
+        if extensions.to_device is not None:
+            result["to_device"] = {
+                "next_batch": extensions.to_device.next_batch,
+                "events": extensions.to_device.events,
+            }
+
+        return result
+
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     SyncRestServlet(hs).register(http_server)
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 43dcdf20dd..a8a3a8f242 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -18,7 +18,7 @@
 #
 #
 from enum import Enum
-from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple
 
 import attr
 from typing_extensions import TypedDict
@@ -252,10 +252,39 @@ class SlidingSyncResult:
         count: int
         ops: List[Operation]
 
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class Extensions:
+        """Responses for extensions
+
+        Attributes:
+            to_device: The to-device extension (MSC3885)
+        """
+
+        @attr.s(slots=True, frozen=True, auto_attribs=True)
+        class ToDeviceExtension:
+            """The to-device extension (MSC3885)
+
+            Attributes:
+                next_batch: The to-device stream token the client should use
+                    to get more results
+                events: A list of to-device messages for the client
+            """
+
+            next_batch: str
+            events: Sequence[JsonMapping]
+
+            def __bool__(self) -> bool:
+                return bool(self.events)
+
+        to_device: Optional[ToDeviceExtension] = None
+
+        def __bool__(self) -> bool:
+            return bool(self.to_device)
+
     next_pos: StreamToken
     lists: Dict[str, SlidingWindowList]
     rooms: Dict[str, RoomResult]
-    extensions: JsonMapping
+    extensions: Extensions
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -271,5 +300,5 @@ class SlidingSyncResult:
             next_pos=next_pos,
             lists={},
             rooms={},
-            extensions={},
+            extensions=SlidingSyncResult.Extensions(),
         )
diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index 55f6b44053..1e8fe76c99 100644
--- a/synapse/types/rest/client/__init__.py
+++ b/synapse/types/rest/client/__init__.py
@@ -276,10 +276,48 @@ class SlidingSyncBody(RequestBodyModel):
     class RoomSubscription(CommonRoomParameters):
         pass
 
-    class Extension(RequestBodyModel):
-        enabled: Optional[StrictBool] = False
-        lists: Optional[List[StrictStr]] = None
-        rooms: Optional[List[StrictStr]] = None
+    class Extensions(RequestBodyModel):
+        """The extensions section of the request.
+
+        Extensions MUST have an `enabled` flag which defaults to `false`. If a client
+        sends an unknown extension name, the server MUST ignore it (or else backwards
+        compatibility between clients and servers is broken when a newer client tries to
+        communicate with an older server).
+        """
+
+        class ToDeviceExtension(RequestBodyModel):
+            """The to-device extension (MSC3885)
+
+            Attributes:
+                enabled
+                limit: Maximum number of to-device messages to return
+                since: The `next_batch` from the previous sync response
+            """
+
+            enabled: Optional[StrictBool] = False
+            limit: StrictInt = 100
+            since: Optional[StrictStr] = None
+
+            @validator("since")
+            def since_token_check(
+                cls, value: Optional[StrictStr]
+            ) -> Optional[StrictStr]:
+                # `since` comes in as an opaque string token but we know that it's just
+                # an integer representing the position in the device inbox stream. We
+                # want to pre-validate it to make sure it works fine in downstream code.
+                if value is None:
+                    return value
+
+                try:
+                    int(value)
+                except ValueError:
+                    raise ValueError(
+                        "'extensions.to_device.since' is invalid (should look like an int)"
+                    )
+
+                return value
+
+        to_device: Optional[ToDeviceExtension] = None
 
     # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
     if TYPE_CHECKING:
@@ -287,7 +325,7 @@ class SlidingSyncBody(RequestBodyModel):
     else:
         lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None  # type: ignore[valid-type]
     room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
-    extensions: Optional[Dict[StrictStr, Extension]] = None
+    extensions: Optional[Extensions] = None
 
     @validator("lists")
     def lists_length_check(