diff --git a/changelog.d/17416.feature b/changelog.d/17416.feature
new file mode 100644
index 0000000000..1d119cf48f
--- /dev/null
+++ b/changelog.d/17416.feature
@@ -0,0 +1 @@
+Add to-device extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index bb81ca9d97..818b13621c 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -542,11 +542,15 @@ class SlidingSyncHandler:
rooms[room_id] = room_sync_result
+ extensions = await self.get_extensions_response(
+ sync_config=sync_config, to_token=to_token
+ )
+
return SlidingSyncResult(
next_pos=to_token,
lists=lists,
rooms=rooms,
- extensions={},
+ extensions=extensions,
)
async def get_sync_room_ids_for_user(
@@ -1445,3 +1449,100 @@ class SlidingSyncHandler:
notification_count=0,
highlight_count=0,
)
+
+ async def get_extensions_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ to_token: StreamToken,
+ ) -> SlidingSyncResult.Extensions:
+ """Handle extension requests.
+
+ Args:
+ sync_config: Sync configuration
+ to_token: The point in the stream to sync up to.
+ """
+
+ if sync_config.extensions is None:
+ return SlidingSyncResult.Extensions()
+
+ to_device_response = None
+ if sync_config.extensions.to_device:
+ to_device_response = await self.get_to_device_extensions_response(
+ sync_config=sync_config,
+ to_device_request=sync_config.extensions.to_device,
+ to_token=to_token,
+ )
+
+ return SlidingSyncResult.Extensions(to_device=to_device_response)
+
+ async def get_to_device_extensions_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
+ to_token: StreamToken,
+ ) -> SlidingSyncResult.Extensions.ToDeviceExtension:
+ """Handle to-device extension (MSC3885)
+
+ Args:
+ sync_config: Sync configuration
+ to_device_request: The to-device extension from the request
+ to_token: The point in the stream to sync up to.
+ """
+
+ user_id = sync_config.user.to_string()
+ device_id = sync_config.device_id
+
+ # Check that this request has a valid device ID (not all requests have
+ # to belong to a device, and so device_id is None), and that the
+ # extension is enabled.
+ if device_id is None or not to_device_request.enabled:
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=f"{to_token.to_device_key}",
+ events=[],
+ )
+
+ since_stream_id = 0
+ if to_device_request.since is not None:
+ # We've already validated this is an int.
+ since_stream_id = int(to_device_request.since)
+
+ if to_token.to_device_key < since_stream_id:
+ # The since token is ahead of our current token, so we return an
+ # empty response.
+ logger.warning(
+ "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
+ since_stream_id,
+ to_token.to_device_key,
+ )
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=to_device_request.since,
+ events=[],
+ )
+
+ # Delete everything before the given since token, as we know the
+ # device must have received them.
+ deleted = await self.store.delete_messages_for_device(
+ user_id=user_id,
+ device_id=device_id,
+ up_to_stream_id=since_stream_id,
+ )
+
+ logger.debug(
+ "Deleted %d to-device messages up to %d for %s",
+ deleted,
+ since_stream_id,
+ user_id,
+ )
+
+ messages, stream_id = await self.store.get_messages_for_device(
+ user_id=user_id,
+ device_id=device_id,
+ from_stream_id=since_stream_id,
+ to_stream_id=to_token.to_device_key,
+ limit=min(to_device_request.limit, 100), # Limit to at most 100 events
+ )
+
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=f"{stream_id}",
+ events=messages,
+ )
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 13aed1dc85..94d5faf9f7 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet):
response["rooms"] = await self.encode_rooms(
requester, sliding_sync_result.rooms
)
- response["extensions"] = {} # TODO: sliding_sync_result.extensions
+ response["extensions"] = await self.encode_extensions(
+ requester, sliding_sync_result.extensions
+ )
return response
@@ -1054,6 +1056,19 @@ class SlidingSyncRestServlet(RestServlet):
return serialized_rooms
+ async def encode_extensions(
+ self, requester: Requester, extensions: SlidingSyncResult.Extensions
+ ) -> JsonDict:
+ result = {}
+
+ if extensions.to_device is not None:
+ result["to_device"] = {
+ "next_batch": extensions.to_device.next_batch,
+ "events": extensions.to_device.events,
+ }
+
+ return result
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 43dcdf20dd..a8a3a8f242 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -18,7 +18,7 @@
#
#
from enum import Enum
-from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple
import attr
from typing_extensions import TypedDict
@@ -252,10 +252,39 @@ class SlidingSyncResult:
count: int
ops: List[Operation]
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class Extensions:
+ """Responses for extensions
+
+ Attributes:
+ to_device: The to-device extension (MSC3885)
+ """
+
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class ToDeviceExtension:
+ """The to-device extension (MSC3885)
+
+ Attributes:
+ next_batch: The to-device stream token the client should use
+ to get more results
+ events: A list of to-device messages for the client
+ """
+
+ next_batch: str
+ events: Sequence[JsonMapping]
+
+ def __bool__(self) -> bool:
+ return bool(self.events)
+
+ to_device: Optional[ToDeviceExtension] = None
+
+ def __bool__(self) -> bool:
+ return bool(self.to_device)
+
next_pos: StreamToken
lists: Dict[str, SlidingWindowList]
rooms: Dict[str, RoomResult]
- extensions: JsonMapping
+ extensions: Extensions
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -271,5 +300,5 @@ class SlidingSyncResult:
next_pos=next_pos,
lists={},
rooms={},
- extensions={},
+ extensions=SlidingSyncResult.Extensions(),
)
diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index 55f6b44053..1e8fe76c99 100644
--- a/synapse/types/rest/client/__init__.py
+++ b/synapse/types/rest/client/__init__.py
@@ -276,10 +276,48 @@ class SlidingSyncBody(RequestBodyModel):
class RoomSubscription(CommonRoomParameters):
pass
- class Extension(RequestBodyModel):
- enabled: Optional[StrictBool] = False
- lists: Optional[List[StrictStr]] = None
- rooms: Optional[List[StrictStr]] = None
+ class Extensions(RequestBodyModel):
+ """The extensions section of the request.
+
+ Extensions MUST have an `enabled` flag which defaults to `false`. If a client
+ sends an unknown extension name, the server MUST ignore it (or else backwards
+ compatibility between clients and servers is broken when a newer client tries to
+ communicate with an older server).
+ """
+
+ class ToDeviceExtension(RequestBodyModel):
+ """The to-device extension (MSC3885)
+
+ Attributes:
+ enabled
+ limit: Maximum number of to-device messages to return
+ since: The `next_batch` from the previous sync response
+ """
+
+ enabled: Optional[StrictBool] = False
+ limit: StrictInt = 100
+ since: Optional[StrictStr] = None
+
+ @validator("since")
+ def since_token_check(
+ cls, value: Optional[StrictStr]
+ ) -> Optional[StrictStr]:
+ # `since` comes in as an opaque string token but we know that it's just
+ # an integer representing the position in the device inbox stream. We
+ # want to pre-validate it to make sure it works fine in downstream code.
+ if value is None:
+ return value
+
+ try:
+ int(value)
+ except ValueError:
+ raise ValueError(
+ "'extensions.to_device.since' is invalid (should look like an int)"
+ )
+
+ return value
+
+ to_device: Optional[ToDeviceExtension] = None
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
if TYPE_CHECKING:
@@ -287,7 +325,7 @@ class SlidingSyncBody(RequestBodyModel):
else:
lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type]
room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
- extensions: Optional[Dict[StrictStr, Extension]] = None
+ extensions: Optional[Extensions] = None
@validator("lists")
def lists_length_check(
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index f7852562b1..304c0d4d3d 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -38,7 +38,16 @@ from synapse.api.constants import (
)
from synapse.events import EventBase
from synapse.handlers.sliding_sync import StateValues
-from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
+from synapse.rest.client import (
+ devices,
+ knock,
+ login,
+ read_marker,
+ receipts,
+ room,
+ sendtodevice,
+ sync,
+)
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
from synapse.util import Clock
@@ -47,7 +56,7 @@ from tests import unittest
from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin,
)
-from tests.server import TimedOutException
+from tests.server import FakeChannel, TimedOutException
from tests.test_utils.event_injection import mark_event_as_partial_state
logger = logging.getLogger(__name__)
@@ -3696,3 +3705,190 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
],
channel.json_body["lists"]["foo-list"],
)
+
+
+class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
+ """Tests for the to-device sliding sync extension"""
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ sendtodevice.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ # Enable sliding sync
+ config["experimental_features"] = {"msc3575_enabled": True}
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.sync_endpoint = (
+ "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
+ )
+
+ def _assert_to_device_response(
+ self, channel: FakeChannel, expected_messages: List[JsonDict]
+ ) -> str:
+ """Assert the sliding sync response was successful and has the expected
+ to-device messages.
+
+ Returns the next_batch token from the to-device section.
+ """
+ self.assertEqual(channel.code, 200, channel.json_body)
+ extensions = channel.json_body["extensions"]
+ to_device = extensions["to_device"]
+ self.assertIsInstance(to_device["next_batch"], str)
+ self.assertEqual(to_device["events"], expected_messages)
+
+ return to_device["next_batch"]
+
+ def test_no_data(self) -> None:
+ """Test that enabling to-device extension works, even if there is
+ no-data
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+
+ # We expect no to-device messages
+ self._assert_to_device_response(channel, [])
+
+ def test_data_initial_sync(self) -> None:
+ """Test that we get to-device messages when we don't specify a since
+ token"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass", "d1")
+ user2_id = self.register_user("u2", "pass")
+ user2_tok = self.login(user2_id, "pass", "d2")
+
+ # Send the to-device message
+ test_msg = {"foo": "bar"}
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.test/1234",
+ content={"messages": {user1_id: {"d1": test_msg}}},
+ access_token=user2_tok,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self._assert_to_device_response(
+ channel,
+ [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
+ )
+
+ def test_data_incremental_sync(self) -> None:
+ """Test that we get to-device messages over incremental syncs"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass", "d1")
+ user2_id = self.register_user("u2", "pass")
+ user2_tok = self.login(user2_id, "pass", "d2")
+
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ # No to-device messages yet.
+ next_batch = self._assert_to_device_response(channel, [])
+
+ test_msg = {"foo": "bar"}
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.test/1234",
+ content={"messages": {user1_id: {"d1": test_msg}}},
+ access_token=user2_tok,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ "since": next_batch,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ next_batch = self._assert_to_device_response(
+ channel,
+ [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
+ )
+
+ # The next sliding sync request should not include the to-device
+ # message.
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ "since": next_batch,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self._assert_to_device_response(channel, [])
+
+ # An initial sliding sync request should not include the to-device
+ # message, as it should have been deleted
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self._assert_to_device_response(channel, [])
|