summary refs log tree commit diff
path: root/tests/rest/client/test_sendtodevice.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_sendtodevice.py')
-rw-r--r--tests/rest/client/test_sendtodevice.py71
1 files changed, 58 insertions, 13 deletions
diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 2f994ad553..5ef501c6d5 100644
--- a/tests/rest/client/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -18,15 +18,39 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
+from parameterized import parameterized_class
 
 from synapse.api.constants import EduTypes
 from synapse.rest import admin
 from synapse.rest.client import login, sendtodevice, sync
+from synapse.types import JsonDict
 
 from tests.unittest import HomeserverTestCase, override_config
 
 
+@parameterized_class(
+    ("sync_endpoint", "experimental_features"),
+    [
+        ("/sync", {}),
+        (
+            "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+            # Enable sliding sync
+            {"msc3575_enabled": True},
+        ),
+    ],
+)
 class SendToDeviceTestCase(HomeserverTestCase):
+    """
+    Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
+
+    Attributes:
+        sync_endpoint: The endpoint under test to use for syncing.
+        experimental_features: The experimental features homeserver config to use.
+    """
+
+    sync_endpoint: str
+    experimental_features: JsonDict
+
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -34,6 +58,11 @@ class SendToDeviceTestCase(HomeserverTestCase):
         sync.register_servlets,
     ]
 
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        config["experimental_features"] = self.experimental_features
+        return config
+
     def test_user_to_user(self) -> None:
         """A to-device message from one user to another should get delivered"""
 
@@ -54,7 +83,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
         self.assertEqual(chan.code, 200, chan.result)
 
         # check it appears
-        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
         self.assertEqual(channel.code, 200, channel.result)
         expected_result = {
             "events": [
@@ -67,15 +96,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
         }
         self.assertEqual(channel.json_body["to_device"], expected_result)
 
-        # it should re-appear if we do another sync
-        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        # it should re-appear if we do another sync because the to-device message is not
+        # deleted until we acknowledge it by sending a `?since=...` parameter in the
+        # next sync request corresponding to the `next_batch` value from the response.
+        channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual(channel.json_body["to_device"], expected_result)
 
         # it should *not* appear if we do an incremental sync
         sync_token = channel.json_body["next_batch"]
         channel = self.make_request(
-            "GET", f"/sync?since={sync_token}", access_token=user2_tok
+            "GET",
+            f"{self.sync_endpoint}?since={sync_token}",
+            access_token=user2_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@@ -99,15 +132,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
             )
             self.assertEqual(chan.code, 200, chan.result)
 
-        # now sync: we should get two of the three
-        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        # now sync: we should get two of the three (because burst_count=2)
+        channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
         self.assertEqual(channel.code, 200, channel.result)
         msgs = channel.json_body["to_device"]["events"]
         self.assertEqual(len(msgs), 2)
         for i in range(2):
             self.assertEqual(
                 msgs[i],
-                {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
+                {
+                    "sender": user1,
+                    "type": "m.room_key_request",
+                    "content": {"idx": i},
+                },
             )
         sync_token = channel.json_body["next_batch"]
 
@@ -125,7 +162,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
 
         # ... which should arrive
         channel = self.make_request(
-            "GET", f"/sync?since={sync_token}", access_token=user2_tok
+            "GET",
+            f"{self.sync_endpoint}?since={sync_token}",
+            access_token=user2_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
         msgs = channel.json_body["to_device"]["events"]
@@ -159,7 +198,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
             )
 
         # now sync: we should get two of the three
-        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
         self.assertEqual(channel.code, 200, channel.result)
         msgs = channel.json_body["to_device"]["events"]
         self.assertEqual(len(msgs), 2)
@@ -193,7 +232,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
 
         # ... which should arrive
         channel = self.make_request(
-            "GET", f"/sync?since={sync_token}", access_token=user2_tok
+            "GET",
+            f"{self.sync_endpoint}?since={sync_token}",
+            access_token=user2_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
         msgs = channel.json_body["to_device"]["events"]
@@ -217,7 +258,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
         user2_tok = self.login("u2", "pass", "d2")
 
         # Do an initial sync
-        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
         self.assertEqual(channel.code, 200, channel.result)
         sync_token = channel.json_body["next_batch"]
 
@@ -233,7 +274,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
             self.assertEqual(chan.code, 200, chan.result)
 
         channel = self.make_request(
-            "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+            "GET",
+            f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+            access_token=user2_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
         messages = channel.json_body.get("to_device", {}).get("events", [])
@@ -241,7 +284,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
         sync_token = channel.json_body["next_batch"]
 
         channel = self.make_request(
-            "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+            "GET",
+            f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+            access_token=user2_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
         messages = channel.json_body.get("to_device", {}).get("events", [])