diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 2f994ad553..5ef501c6d5 100644
--- a/tests/rest/client/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -18,15 +18,39 @@
# [This file includes modifications made by New Vector Limited]
#
#
+from parameterized import parameterized_class
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
+from synapse.types import JsonDict
from tests.unittest import HomeserverTestCase, override_config
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
class SendToDeviceTestCase(HomeserverTestCase):
+ """
+ Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -34,6 +58,11 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync.register_servlets,
]
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
@@ -54,7 +83,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
# check it appears
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
@@ -67,15 +96,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
}
self.assertEqual(channel.json_body["to_device"], expected_result)
- # it should re-appear if we do another sync
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ # it should re-appear if we do another sync because the to-device message is not
+ # deleted until we acknowledge it by sending a `?since=...` parameter in the
+ # next sync request corresponding to the `next_batch` value from the response.
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@@ -99,15 +132,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 200, chan.result)
- # now sync: we should get two of the three
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ # now sync: we should get two of the three (because burst_count=2)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
- {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
+ {
+ "sender": user1,
+ "type": "m.room_key_request",
+ "content": {"idx": i},
+ },
)
sync_token = channel.json_body["next_batch"]
@@ -125,7 +162,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -159,7 +198,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
# now sync: we should get two of the three
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
@@ -193,7 +232,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -217,7 +258,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
@@ -233,7 +274,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
- "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
@@ -241,7 +284,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
- "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
|