summary refs log tree commit diff
diff options
context:
space:
mode:
authorMathieu Velten <matmaul@gmail.com>2023-06-23 15:22:00 +0200
committerMathieu Velten <matmaul@gmail.com>2023-06-23 15:22:00 +0200
commit921fa8f9ceafdef2c204f600348e383403c8ddfe (patch)
tree4c2cd64aa7f62577b937b439accd894f8caae7b5
parentIgnore key requests if the device inbox is already big (diff)
downloadsynapse-mv/key_request_limit.tar.xz
-rw-r--r--synapse/storage/databases/main/deviceinbox.py14
-rw-r--r--tests/handlers/test_device.py85
2 files changed, 88 insertions, 11 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1eb4501c99..b13d5bd48e 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -816,7 +816,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
         local_by_user_then_device = {}
         for user_id, messages_by_device in messages_by_user_then_device.items():
-            inboxes_size = {}
+            inbox_sizes = {}
             if size_limit:
                 sql = """
                     SELECT device_id, COUNT(*) FROM device_inbox
@@ -825,7 +825,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 """
                 txn.execute(sql, (user_id,))
                 for r in txn:
-                    inboxes_size[r[0]] = r[1]
+                    inbox_sizes[r[0]] = r[1]
 
             messages_json_for_user = {}
             devices = list(messages_by_device.keys())
@@ -842,10 +842,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
                 message_json = json_encoder.encode(messages_by_device["*"])
                 for device_id in devices:
-                    if (
-                        size_limit is None
-                        or inboxes_size.get(device_id, 0) <= size_limit
-                    ):
+                    if size_limit is None or inbox_sizes.get(device_id, 0) < size_limit:
                         # Add the message for all devices for this user on this
                         # server.
                         messages_json_for_user[device_id] = message_json
@@ -881,10 +878,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                         )
                         message_json = json_encoder.encode(msg)
 
-                    if (
-                        size_limit is None
-                        or inboxes_size.get(device_id, 0) <= size_limit
-                    ):
+                    if size_limit is None or inbox_sizes.get(device_id, 0) < size_limit:
                         messages_json_for_user[device_id] = message_json
 
             if messages_json_for_user:
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index ee48f9e546..821c31c7c0 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -23,20 +23,28 @@ from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
+from synapse.handlers.devicemessage import INBOX_SIZE_LIMIT_FOR_KEY_REQUEST
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
+import synapse
+
 user1 = "@boris:aaa"
 user2 = "@theresa:bbb"
 
 
 class DeviceTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        synapse.rest.client.login.register_servlets,
+    ]
+
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.appservice_api = mock.Mock()
         hs = self.setup_test_homeserver(
@@ -47,6 +55,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         handler = hs.get_device_handler()
         assert isinstance(handler, DeviceHandler)
         self.handler = handler
+        self.msg_handler = hs.get_device_message_handler()
+        self.event_sources = hs.get_event_sources()
         self.store = hs.get_datastores().main
         return hs
 
@@ -398,6 +408,79 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             ],
         )
 
+    def test_room_key_request_limit(self) -> None:
+        store = self.hs.get_datastores().main
+
+        myuser = self.register_user("myuser", "pass")
+        self.login("myuser", "pass", "device")
+        self.login("myuser", "pass", "device2")
+
+        requester = requester = create_requester(myuser)
+
+        from_token = self.event_sources.get_current_token()
+
+        #     for i in range(0, INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2):
+        #         self.get_success(
+        #             self.msg_handler.send_device_message(
+        #                 requester,
+        #                 "m.room_key",
+        #                 {
+        #                     myuser2: {
+        #                         "device": {
+        # "algorithm": "m.megolm.v1.aes-sha2",
+        # "room_id": "!Cuyf34gef24t:localhost",
+        # "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ",
+        # "session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY..."
+
+        #                         }
+        #                     }
+        #                 },
+        #             )
+        #         )
+
+        #     to_token = self.event_sources.get_current_token()
+
+        # res = self.get_success(self.store.get_messages_for_device(
+        #     myuser2,
+        #     "device",
+        #     from_token.to_device_key,
+        #     to_token.to_device_key,
+        #     INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 5,
+        # ))
+        # self.assertEqual(len(res[0]), INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2)
+
+        # from_token = to_token
+
+        for i in range(0, INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2):
+            self.get_success(
+                self.msg_handler.send_device_message(
+                    requester,
+                    "m.room_key_request",
+                    {
+                        myuser: {
+                            "device2": {
+                                "action": "request",
+                                "request_id": f"request_id_{i}",
+                                "requesting_device_id": "device",
+                            }
+                        }
+                    },
+                )
+            )
+
+        to_token = self.event_sources.get_current_token()
+
+        res = self.get_success(
+            self.store.get_messages_for_device(
+                myuser,
+                "device2",
+                from_token.to_device_key,
+                to_token.to_device_key,
+                INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 5,
+            )
+        )
+        self.assertEqual(len(res[0]), INBOX_SIZE_LIMIT_FOR_KEY_REQUEST)
+
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: