diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index ee48f9e546..0e34ee208a 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -19,13 +19,14 @@ from unittest import mock
from twisted.test.proto_helpers import MemoryReactor
+import synapse
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import NotFoundError, SynapseError
from synapse.appservice import ApplicationService
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
@@ -37,6 +38,11 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ ]
+
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice_api = mock.Mock()
hs = self.setup_test_homeserver(
@@ -47,6 +53,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
+ self.msg_handler = hs.get_device_message_handler()
+ self.event_sources = hs.get_event_sources()
self.store = hs.get_datastores().main
return hs
@@ -398,6 +406,117 @@ class DeviceTestCase(unittest.HomeserverTestCase):
],
)
+ @override_config({"experimental_features": {"msc3944_enabled": True}})
+ def test_duplicated_and_cancelled_room_key_request(self) -> None:
+ myuser = self.register_user("myuser", "pass")
+ self.login("myuser", "pass", "device")
+ self.login("myuser", "pass", "device2")
+ self.login("myuser", "pass", "device3")
+
+ requester = requester = create_requester(myuser)
+
+ from_token = self.event_sources.get_current_token()
+
+ # This room_key_request is for device3 and should not be deleted.
+ self.get_success(
+ self.msg_handler.send_device_message(
+ requester,
+ "m.room_key_request",
+ {
+ myuser: {
+ "device3": {
+ "action": "request",
+ "request_id": "request_id",
+ "requesting_device_id": "device",
+ }
+ }
+ },
+ )
+ )
+
+ for _ in range(0, 2):
+ self.get_success(
+ self.msg_handler.send_device_message(
+ requester,
+ "m.room_key_request",
+ {
+ myuser: {
+ "device2": {
+ "action": "request",
+ "request_id": "request_id",
+ "requesting_device_id": "device",
+ }
+ }
+ },
+ )
+ )
+
+ to_token = self.event_sources.get_current_token()
+
+ # Test that if we queue 2 identical room_key_request,
+ # only one is delivered to the device.
+ res = self.get_success(
+ self.store.get_messages_for_device(
+ myuser,
+ "device2",
+ from_token.to_device_key,
+ to_token.to_device_key,
+ )
+ )
+ self.assertEqual(len(res[0]), 1)
+
+ # room_key_request for device3 should still be around.
+ res = self.get_success(
+ self.store.get_messages_for_device(
+ myuser,
+ "device3",
+ from_token.to_device_key,
+ to_token.to_device_key,
+ )
+ )
+ self.assertEqual(len(res[0]), 1)
+
+ self.get_success(
+ self.msg_handler.send_device_message(
+ requester,
+ "m.room_key_request",
+ {
+ myuser: {
+ "device2": {
+ "action": "request_cancellation",
+ "request_id": "request_id",
+ "requesting_device_id": "device",
+ }
+ }
+ },
+ )
+ )
+
+ to_token = self.event_sources.get_current_token()
+
+ # Test that if we cancel a room_key_request, both previous matching
+ # requests and the cancelled request are not delivered to the device.
+ res = self.get_success(
+ self.store.get_messages_for_device(
+ myuser,
+ "device2",
+ from_token.to_device_key,
+ to_token.to_device_key,
+ )
+ )
+ self.assertEqual(len(res[0]), 0)
+
+ # room_key_request for device3 should still be around.
+ res = self.get_success(
+ self.store.get_messages_for_device(
+ myuser,
+ "device3",
+ from_token.to_device_key,
+ to_token.to_device_key,
+ )
+ )
+ self.assertEqual(len(res[0]), 1)
+
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|