diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index f7852562b1..304c0d4d3d 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -38,7 +38,16 @@ from synapse.api.constants import (
)
from synapse.events import EventBase
from synapse.handlers.sliding_sync import StateValues
-from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
+from synapse.rest.client import (
+ devices,
+ knock,
+ login,
+ read_marker,
+ receipts,
+ room,
+ sendtodevice,
+ sync,
+)
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
from synapse.util import Clock
@@ -47,7 +56,7 @@ from tests import unittest
from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin,
)
-from tests.server import TimedOutException
+from tests.server import FakeChannel, TimedOutException
from tests.test_utils.event_injection import mark_event_as_partial_state
logger = logging.getLogger(__name__)
@@ -3696,3 +3705,190 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
],
channel.json_body["lists"]["foo-list"],
)
+
+
+class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
+ """Tests for the to-device sliding sync extension"""
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ sendtodevice.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ # Enable sliding sync
+ config["experimental_features"] = {"msc3575_enabled": True}
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.sync_endpoint = (
+ "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
+ )
+
+ def _assert_to_device_response(
+ self, channel: FakeChannel, expected_messages: List[JsonDict]
+ ) -> str:
+ """Assert the sliding sync response was successful and has the expected
+ to-device messages.
+
+ Returns the next_batch token from the to-device section.
+ """
+ self.assertEqual(channel.code, 200, channel.json_body)
+ extensions = channel.json_body["extensions"]
+ to_device = extensions["to_device"]
+ self.assertIsInstance(to_device["next_batch"], str)
+ self.assertEqual(to_device["events"], expected_messages)
+
+ return to_device["next_batch"]
+
+ def test_no_data(self) -> None:
+ """Test that enabling to-device extension works, even if there is
+ no-data
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+
+ # We expect no to-device messages
+ self._assert_to_device_response(channel, [])
+
+ def test_data_initial_sync(self) -> None:
+ """Test that we get to-device messages when we don't specify a since
+ token"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass", "d1")
+ user2_id = self.register_user("u2", "pass")
+ user2_tok = self.login(user2_id, "pass", "d2")
+
+ # Send the to-device message
+ test_msg = {"foo": "bar"}
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.test/1234",
+ content={"messages": {user1_id: {"d1": test_msg}}},
+ access_token=user2_tok,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self._assert_to_device_response(
+ channel,
+ [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
+ )
+
+ def test_data_incremental_sync(self) -> None:
+ """Test that we get to-device messages over incremental syncs"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass", "d1")
+ user2_id = self.register_user("u2", "pass")
+ user2_tok = self.login(user2_id, "pass", "d2")
+
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ # No to-device messages yet.
+ next_batch = self._assert_to_device_response(channel, [])
+
+ test_msg = {"foo": "bar"}
+ chan = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/sendToDevice/m.test/1234",
+ content={"messages": {user1_id: {"d1": test_msg}}},
+ access_token=user2_tok,
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ "since": next_batch,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ next_batch = self._assert_to_device_response(
+ channel,
+ [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
+ )
+
+ # The next sliding sync request should not include the to-device
+ # message.
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ "since": next_batch,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self._assert_to_device_response(channel, [])
+
+ # An initial sliding sync request should not include the to-device
+ # message, as it should have been deleted
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self._assert_to_device_response(channel, [])
|