summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-07-10 11:58:42 +0100
committerGitHub <noreply@github.com>2024-07-10 11:58:42 +0100
commit4ca13ce0dd6d1dc931cfde7e06191200ca0ec066 (patch)
tree1efce62ab2279a8881926e32275ecbb40f4725d7
parentMerge branch 'release-v1.111' into develop (diff)
downloadsynapse-4ca13ce0dd6d1dc931cfde7e06191200ca0ec066.tar.xz
Handle to-device extensions to Sliding Sync (#17416)
Implements MSC3885

---------

Co-authored-by: Eric Eastwood <eric.eastwood@beta.gouv.fr>
-rw-r--r--changelog.d/17416.feature1
-rw-r--r--synapse/handlers/sliding_sync.py103
-rw-r--r--synapse/rest/client/sync.py17
-rw-r--r--synapse/types/handlers/__init__.py35
-rw-r--r--synapse/types/rest/client/__init__.py48
-rw-r--r--tests/rest/client/test_sync.py200
6 files changed, 392 insertions, 12 deletions
diff --git a/changelog.d/17416.feature b/changelog.d/17416.feature
new file mode 100644
index 0000000000..1d119cf48f
--- /dev/null
+++ b/changelog.d/17416.feature
@@ -0,0 +1 @@
+Add to-device extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index bb81ca9d97..818b13621c 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -542,11 +542,15 @@ class SlidingSyncHandler:
 
             rooms[room_id] = room_sync_result
 
+        extensions = await self.get_extensions_response(
+            sync_config=sync_config, to_token=to_token
+        )
+
         return SlidingSyncResult(
             next_pos=to_token,
             lists=lists,
             rooms=rooms,
-            extensions={},
+            extensions=extensions,
         )
 
     async def get_sync_room_ids_for_user(
@@ -1445,3 +1449,100 @@ class SlidingSyncHandler:
             notification_count=0,
             highlight_count=0,
         )
+
+    async def get_extensions_response(
+        self,
+        sync_config: SlidingSyncConfig,
+        to_token: StreamToken,
+    ) -> SlidingSyncResult.Extensions:
+        """Handle extension requests.
+
+        Args:
+            sync_config: Sync configuration
+            to_token: The point in the stream to sync up to.
+        """
+
+        if sync_config.extensions is None:
+            return SlidingSyncResult.Extensions()
+
+        to_device_response = None
+        if sync_config.extensions.to_device:
+            to_device_response = await self.get_to_device_extensions_response(
+                sync_config=sync_config,
+                to_device_request=sync_config.extensions.to_device,
+                to_token=to_token,
+            )
+
+        return SlidingSyncResult.Extensions(to_device=to_device_response)
+
+    async def get_to_device_extensions_response(
+        self,
+        sync_config: SlidingSyncConfig,
+        to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
+        to_token: StreamToken,
+    ) -> SlidingSyncResult.Extensions.ToDeviceExtension:
+        """Handle to-device extension (MSC3885)
+
+        Args:
+            sync_config: Sync configuration
+            to_device_request: The to-device extension from the request
+            to_token: The point in the stream to sync up to.
+        """
+
+        user_id = sync_config.user.to_string()
+        device_id = sync_config.device_id
+
+        # Check that this request has a valid device ID (not all requests have
+        # to belong to a device, and so device_id is None), and that the
+        # extension is enabled.
+        if device_id is None or not to_device_request.enabled:
+            return SlidingSyncResult.Extensions.ToDeviceExtension(
+                next_batch=f"{to_token.to_device_key}",
+                events=[],
+            )
+
+        since_stream_id = 0
+        if to_device_request.since is not None:
+            # We've already validated this is an int.
+            since_stream_id = int(to_device_request.since)
+
+            if to_token.to_device_key < since_stream_id:
+                # The since token is ahead of our current token, so we return an
+                # empty response.
+                logger.warning(
+                    "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
+                    since_stream_id,
+                    to_token.to_device_key,
+                )
+                return SlidingSyncResult.Extensions.ToDeviceExtension(
+                    next_batch=to_device_request.since,
+                    events=[],
+                )
+
+            # Delete everything before the given since token, as we know the
+            # device must have received them.
+            deleted = await self.store.delete_messages_for_device(
+                user_id=user_id,
+                device_id=device_id,
+                up_to_stream_id=since_stream_id,
+            )
+
+            logger.debug(
+                "Deleted %d to-device messages up to %d for %s",
+                deleted,
+                since_stream_id,
+                user_id,
+            )
+
+        messages, stream_id = await self.store.get_messages_for_device(
+            user_id=user_id,
+            device_id=device_id,
+            from_stream_id=since_stream_id,
+            to_stream_id=to_token.to_device_key,
+            limit=min(to_device_request.limit, 100),  # Limit to at most 100 events
+        )
+
+        return SlidingSyncResult.Extensions.ToDeviceExtension(
+            next_batch=f"{stream_id}",
+            events=messages,
+        )
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 13aed1dc85..94d5faf9f7 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet):
         response["rooms"] = await self.encode_rooms(
             requester, sliding_sync_result.rooms
         )
-        response["extensions"] = {}  # TODO: sliding_sync_result.extensions
+        response["extensions"] = await self.encode_extensions(
+            requester, sliding_sync_result.extensions
+        )
 
         return response
 
@@ -1054,6 +1056,19 @@ class SlidingSyncRestServlet(RestServlet):
 
         return serialized_rooms
 
+    async def encode_extensions(
+        self, requester: Requester, extensions: SlidingSyncResult.Extensions
+    ) -> JsonDict:
+        result = {}
+
+        if extensions.to_device is not None:
+            result["to_device"] = {
+                "next_batch": extensions.to_device.next_batch,
+                "events": extensions.to_device.events,
+            }
+
+        return result
+
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     SyncRestServlet(hs).register(http_server)
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 43dcdf20dd..a8a3a8f242 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -18,7 +18,7 @@
 #
 #
 from enum import Enum
-from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple
 
 import attr
 from typing_extensions import TypedDict
@@ -252,10 +252,39 @@ class SlidingSyncResult:
         count: int
         ops: List[Operation]
 
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class Extensions:
+        """Responses for extensions
+
+        Attributes:
+            to_device: The to-device extension (MSC3885)
+        """
+
+        @attr.s(slots=True, frozen=True, auto_attribs=True)
+        class ToDeviceExtension:
+            """The to-device extension (MSC3885)
+
+            Attributes:
+                next_batch: The to-device stream token the client should use
+                    to get more results
+                events: A list of to-device messages for the client
+            """
+
+            next_batch: str
+            events: Sequence[JsonMapping]
+
+            def __bool__(self) -> bool:
+                return bool(self.events)
+
+        to_device: Optional[ToDeviceExtension] = None
+
+        def __bool__(self) -> bool:
+            return bool(self.to_device)
+
     next_pos: StreamToken
     lists: Dict[str, SlidingWindowList]
     rooms: Dict[str, RoomResult]
-    extensions: JsonMapping
+    extensions: Extensions
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -271,5 +300,5 @@ class SlidingSyncResult:
             next_pos=next_pos,
             lists={},
             rooms={},
-            extensions={},
+            extensions=SlidingSyncResult.Extensions(),
         )
diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index 55f6b44053..1e8fe76c99 100644
--- a/synapse/types/rest/client/__init__.py
+++ b/synapse/types/rest/client/__init__.py
@@ -276,10 +276,48 @@ class SlidingSyncBody(RequestBodyModel):
     class RoomSubscription(CommonRoomParameters):
         pass
 
-    class Extension(RequestBodyModel):
-        enabled: Optional[StrictBool] = False
-        lists: Optional[List[StrictStr]] = None
-        rooms: Optional[List[StrictStr]] = None
+    class Extensions(RequestBodyModel):
+        """The extensions section of the request.
+
+        Extensions MUST have an `enabled` flag which defaults to `false`. If a client
+        sends an unknown extension name, the server MUST ignore it (or else backwards
+        compatibility between clients and servers is broken when a newer client tries to
+        communicate with an older server).
+        """
+
+        class ToDeviceExtension(RequestBodyModel):
+            """The to-device extension (MSC3885)
+
+            Attributes:
+                enabled
+                limit: Maximum number of to-device messages to return
+                since: The `next_batch` from the previous sync response
+            """
+
+            enabled: Optional[StrictBool] = False
+            limit: StrictInt = 100
+            since: Optional[StrictStr] = None
+
+            @validator("since")
+            def since_token_check(
+                cls, value: Optional[StrictStr]
+            ) -> Optional[StrictStr]:
+                # `since` comes in as an opaque string token but we know that it's just
+                # an integer representing the position in the device inbox stream. We
+                # want to pre-validate it to make sure it works fine in downstream code.
+                if value is None:
+                    return value
+
+                try:
+                    int(value)
+                except ValueError:
+                    raise ValueError(
+                        "'extensions.to_device.since' is invalid (should look like an int)"
+                    )
+
+                return value
+
+        to_device: Optional[ToDeviceExtension] = None
 
     # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
     if TYPE_CHECKING:
@@ -287,7 +325,7 @@ class SlidingSyncBody(RequestBodyModel):
     else:
         lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None  # type: ignore[valid-type]
     room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
-    extensions: Optional[Dict[StrictStr, Extension]] = None
+    extensions: Optional[Extensions] = None
 
     @validator("lists")
     def lists_length_check(
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index f7852562b1..304c0d4d3d 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -38,7 +38,16 @@ from synapse.api.constants import (
 )
 from synapse.events import EventBase
 from synapse.handlers.sliding_sync import StateValues
-from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
+from synapse.rest.client import (
+    devices,
+    knock,
+    login,
+    read_marker,
+    receipts,
+    room,
+    sendtodevice,
+    sync,
+)
 from synapse.server import HomeServer
 from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
 from synapse.util import Clock
@@ -47,7 +56,7 @@ from tests import unittest
 from tests.federation.transport.test_knocking import (
     KnockingStrippedStateEventHelperMixin,
 )
-from tests.server import TimedOutException
+from tests.server import FakeChannel, TimedOutException
 from tests.test_utils.event_injection import mark_event_as_partial_state
 
 logger = logging.getLogger(__name__)
@@ -3696,3 +3705,190 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             ],
             channel.json_body["lists"]["foo-list"],
         )
+
+
+class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
+    """Tests for the to-device sliding sync extension"""
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+        sendtodevice.register_servlets,
+    ]
+
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        # Enable sliding sync
+        config["experimental_features"] = {"msc3575_enabled": True}
+        return config
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.sync_endpoint = (
+            "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
+        )
+
+    def _assert_to_device_response(
+        self, channel: FakeChannel, expected_messages: List[JsonDict]
+    ) -> str:
+        """Assert the sliding sync response was successful and has the expected
+        to-device messages.
+
+        Returns the next_batch token from the to-device section.
+        """
+        self.assertEqual(channel.code, 200, channel.json_body)
+        extensions = channel.json_body["extensions"]
+        to_device = extensions["to_device"]
+        self.assertIsInstance(to_device["next_batch"], str)
+        self.assertEqual(to_device["events"], expected_messages)
+
+        return to_device["next_batch"]
+
+    def test_no_data(self) -> None:
+        """Test that enabling to-device extension works, even if there is
+        no-data
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+
+        # We expect no to-device messages
+        self._assert_to_device_response(channel, [])
+
+    def test_data_initial_sync(self) -> None:
+        """Test that we get to-device messages when we don't specify a since
+        token"""
+
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass", "d1")
+        user2_id = self.register_user("u2", "pass")
+        user2_tok = self.login(user2_id, "pass", "d2")
+
+        # Send the to-device message
+        test_msg = {"foo": "bar"}
+        chan = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/sendToDevice/m.test/1234",
+            content={"messages": {user1_id: {"d1": test_msg}}},
+            access_token=user2_tok,
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        self._assert_to_device_response(
+            channel,
+            [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
+        )
+
+    def test_data_incremental_sync(self) -> None:
+        """Test that we get to-device messages over incremental syncs"""
+
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass", "d1")
+        user2_id = self.register_user("u2", "pass")
+        user2_tok = self.login(user2_id, "pass", "d2")
+
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        # No to-device messages yet.
+        next_batch = self._assert_to_device_response(channel, [])
+
+        test_msg = {"foo": "bar"}
+        chan = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/sendToDevice/m.test/1234",
+            content={"messages": {user1_id: {"d1": test_msg}}},
+            access_token=user2_tok,
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                        "since": next_batch,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        next_batch = self._assert_to_device_response(
+            channel,
+            [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
+        )
+
+        # The next sliding sync request should not include the to-device
+        # message.
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                        "since": next_batch,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        self._assert_to_device_response(channel, [])
+
+        # An initial sliding sync request should not include the to-device
+        # message, as it should have been deleted
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        self._assert_to_device_response(channel, [])