summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorMathieu Velten <mathieuv@matrix.org>2023-06-27 00:39:10 +0200
committerMathieu Velten <mathieuv@matrix.org>2023-06-27 09:49:42 +0200
commite25c15ea0f8c7356600c4ade27f7d92a420bea31 (patch)
tree40e4435cf01622e639bc88fd767ba88ddef705fd /tests
parentBump serde_json from 1.0.97 to 1.0.99 (#15832) (diff)
downloadsynapse-mv/msc3944.tar.xz
Implements part of MSC 3944 by dropping cancelled&duplicated `m.room_key_request` github/mv/msc3944 mv/msc3944
Diffstat (limited to '')
-rw-r--r--tests/handlers/test_device.py121
1 files changed, 120 insertions, 1 deletions
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index ee48f9e546..0e34ee208a 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -19,13 +19,14 @@ from unittest import mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
+import synapse
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -37,6 +38,11 @@ user2 = "@theresa:bbb"
 
 
 class DeviceTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        synapse.rest.client.login.register_servlets,
+    ]
+
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.appservice_api = mock.Mock()
         hs = self.setup_test_homeserver(
@@ -47,6 +53,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         handler = hs.get_device_handler()
         assert isinstance(handler, DeviceHandler)
         self.handler = handler
+        self.msg_handler = hs.get_device_message_handler()
+        self.event_sources = hs.get_event_sources()
         self.store = hs.get_datastores().main
         return hs
 
@@ -398,6 +406,117 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             ],
         )
 
+    @override_config({"experimental_features": {"msc3944_enabled": True}})
+    def test_duplicated_and_cancelled_room_key_request(self) -> None:
+        myuser = self.register_user("myuser", "pass")
+        self.login("myuser", "pass", "device")
+        self.login("myuser", "pass", "device2")
+        self.login("myuser", "pass", "device3")
+
+        requester = requester = create_requester(myuser)
+
+        from_token = self.event_sources.get_current_token()
+
+        # This room_key_request is for device3 and should not be deleted.
+        self.get_success(
+            self.msg_handler.send_device_message(
+                requester,
+                "m.room_key_request",
+                {
+                    myuser: {
+                        "device3": {
+                            "action": "request",
+                            "request_id": "request_id",
+                            "requesting_device_id": "device",
+                        }
+                    }
+                },
+            )
+        )
+
+        for _ in range(0, 2):
+            self.get_success(
+                self.msg_handler.send_device_message(
+                    requester,
+                    "m.room_key_request",
+                    {
+                        myuser: {
+                            "device2": {
+                                "action": "request",
+                                "request_id": "request_id",
+                                "requesting_device_id": "device",
+                            }
+                        }
+                    },
+                )
+            )
+
+            to_token = self.event_sources.get_current_token()
+
+            # Test that if we queue 2 identical room_key_request,
+            # only one is delivered to the device.
+            res = self.get_success(
+                self.store.get_messages_for_device(
+                    myuser,
+                    "device2",
+                    from_token.to_device_key,
+                    to_token.to_device_key,
+                )
+            )
+            self.assertEqual(len(res[0]), 1)
+
+        # room_key_request for device3 should still be around.
+        res = self.get_success(
+            self.store.get_messages_for_device(
+                myuser,
+                "device3",
+                from_token.to_device_key,
+                to_token.to_device_key,
+            )
+        )
+        self.assertEqual(len(res[0]), 1)
+
+        self.get_success(
+            self.msg_handler.send_device_message(
+                requester,
+                "m.room_key_request",
+                {
+                    myuser: {
+                        "device2": {
+                            "action": "request_cancellation",
+                            "request_id": "request_id",
+                            "requesting_device_id": "device",
+                        }
+                    }
+                },
+            )
+        )
+
+        to_token = self.event_sources.get_current_token()
+
+        # Test that if we cancel a room_key_request, both previous matching
+        # requests and the cancelled request are not delivered to the device.
+        res = self.get_success(
+            self.store.get_messages_for_device(
+                myuser,
+                "device2",
+                from_token.to_device_key,
+                to_token.to_device_key,
+            )
+        )
+        self.assertEqual(len(res[0]), 0)
+
+        # room_key_request for device3 should still be around.
+        res = self.get_success(
+            self.store.get_messages_for_device(
+                myuser,
+                "device3",
+                from_token.to_device_key,
+                to_token.to_device_key,
+            )
+        )
+        self.assertEqual(len(res[0]), 1)
+
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: