summary refs log tree commit diff
path: root/tests/rest/client/test_sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_sync.py')
-rw-r--r--tests/rest/client/test_sync.py200
1 files changed, 198 insertions, 2 deletions
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index f7852562b1..304c0d4d3d 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -38,7 +38,16 @@ from synapse.api.constants import (
 )
 from synapse.events import EventBase
 from synapse.handlers.sliding_sync import StateValues
-from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
+from synapse.rest.client import (
+    devices,
+    knock,
+    login,
+    read_marker,
+    receipts,
+    room,
+    sendtodevice,
+    sync,
+)
 from synapse.server import HomeServer
 from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
 from synapse.util import Clock
@@ -47,7 +56,7 @@ from tests import unittest
 from tests.federation.transport.test_knocking import (
     KnockingStrippedStateEventHelperMixin,
 )
-from tests.server import TimedOutException
+from tests.server import FakeChannel, TimedOutException
 from tests.test_utils.event_injection import mark_event_as_partial_state
 
 logger = logging.getLogger(__name__)
@@ -3696,3 +3705,190 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             ],
             channel.json_body["lists"]["foo-list"],
         )
+
+
+class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
+    """Tests for the to-device sliding sync extension"""
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+        sendtodevice.register_servlets,
+    ]
+
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        # Enable sliding sync
+        config["experimental_features"] = {"msc3575_enabled": True}
+        return config
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.sync_endpoint = (
+            "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
+        )
+
+    def _assert_to_device_response(
+        self, channel: FakeChannel, expected_messages: List[JsonDict]
+    ) -> str:
+        """Assert the sliding sync response was successful and has the expected
+        to-device messages.
+
+        Returns the next_batch token from the to-device section.
+        """
+        self.assertEqual(channel.code, 200, channel.json_body)
+        extensions = channel.json_body["extensions"]
+        to_device = extensions["to_device"]
+        self.assertIsInstance(to_device["next_batch"], str)
+        self.assertEqual(to_device["events"], expected_messages)
+
+        return to_device["next_batch"]
+
+    def test_no_data(self) -> None:
+        """Test that enabling to-device extension works, even if there is
+        no-data
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+
+        # We expect no to-device messages
+        self._assert_to_device_response(channel, [])
+
+    def test_data_initial_sync(self) -> None:
+        """Test that we get to-device messages when we don't specify a since
+        token"""
+
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass", "d1")
+        user2_id = self.register_user("u2", "pass")
+        user2_tok = self.login(user2_id, "pass", "d2")
+
+        # Send the to-device message
+        test_msg = {"foo": "bar"}
+        chan = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/sendToDevice/m.test/1234",
+            content={"messages": {user1_id: {"d1": test_msg}}},
+            access_token=user2_tok,
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        self._assert_to_device_response(
+            channel,
+            [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
+        )
+
+    def test_data_incremental_sync(self) -> None:
+        """Test that we get to-device messages over incremental syncs"""
+
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass", "d1")
+        user2_id = self.register_user("u2", "pass")
+        user2_tok = self.login(user2_id, "pass", "d2")
+
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        # No to-device messages yet.
+        next_batch = self._assert_to_device_response(channel, [])
+
+        test_msg = {"foo": "bar"}
+        chan = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/sendToDevice/m.test/1234",
+            content={"messages": {user1_id: {"d1": test_msg}}},
+            access_token=user2_tok,
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                        "since": next_batch,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        next_batch = self._assert_to_device_response(
+            channel,
+            [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
+        )
+
+        # The next sliding sync request should not include the to-device
+        # message.
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                        "since": next_batch,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        self._assert_to_device_response(channel, [])
+
+        # An initial sliding sync request should not include the to-device
+        # message, as it should have been deleted
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        self._assert_to_device_response(channel, [])