diff options
-rw-r--r-- | synapse/storage/databases/main/deviceinbox.py | 14 | ||||
-rw-r--r-- | tests/handlers/test_device.py | 85 |
2 files changed, 88 insertions, 11 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 1eb4501c99..b13d5bd48e 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -816,7 +816,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): - inboxes_size = {} + inbox_sizes = {} if size_limit: sql = """ SELECT device_id, COUNT(*) FROM device_inbox @@ -825,7 +825,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): """ txn.execute(sql, (user_id,)) for r in txn: - inboxes_size[r[0]] = r[1] + inbox_sizes[r[0]] = r[1] messages_json_for_user = {} devices = list(messages_by_device.keys()) @@ -842,10 +842,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): message_json = json_encoder.encode(messages_by_device["*"]) for device_id in devices: - if ( - size_limit is None - or inboxes_size.get(device_id, 0) <= size_limit - ): + if size_limit is None or inbox_sizes.get(device_id, 0) < size_limit: # Add the message for all devices for this user on this # server. messages_json_for_user[device_id] = message_json @@ -881,10 +878,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) message_json = json_encoder.encode(msg) - if ( - size_limit is None - or inboxes_size.get(device_id, 0) <= size_limit - ): + if size_limit is None or inbox_sizes.get(device_id, 0) < size_limit: messages_json_for_user[device_id] = message_json if messages_json_for_user: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index ee48f9e546..821c31c7c0 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -23,20 +23,28 @@ from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import NotFoundError, SynapseError from synapse.appservice import ApplicationService from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler +from synapse.handlers.devicemessage import INBOX_SIZE_LIMIT_FOR_KEY_REQUEST from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex -from synapse.types import JsonDict +from synapse.types import JsonDict, create_requester from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable from tests.unittest import override_config +import synapse + user1 = "@boris:aaa" user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.client.login.register_servlets, + ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.appservice_api = mock.Mock() hs = self.setup_test_homeserver( @@ -47,6 +55,8 @@ class DeviceTestCase(unittest.HomeserverTestCase): handler = hs.get_device_handler() assert isinstance(handler, DeviceHandler) self.handler = handler + self.msg_handler = hs.get_device_message_handler() + self.event_sources = hs.get_event_sources() self.store = hs.get_datastores().main return hs @@ -398,6 +408,79 @@ class DeviceTestCase(unittest.HomeserverTestCase): ], ) + def test_room_key_request_limit(self) -> None: + store = self.hs.get_datastores().main + + myuser = self.register_user("myuser", "pass") + self.login("myuser", "pass", "device") + self.login("myuser", "pass", "device2") + + requester = requester = create_requester(myuser) + + from_token = self.event_sources.get_current_token() + + # for i in range(0, INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2): + # self.get_success( + # self.msg_handler.send_device_message( + # requester, + # "m.room_key", + # { + # myuser2: { + # "device": { + # "algorithm": "m.megolm.v1.aes-sha2", + # "room_id": "!Cuyf34gef24t:localhost", + # "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ", + # "session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY..." + + # } + # } + # }, + # ) + # ) + + # to_token = self.event_sources.get_current_token() + + # res = self.get_success(self.store.get_messages_for_device( + # myuser2, + # "device", + # from_token.to_device_key, + # to_token.to_device_key, + # INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 5, + # )) + # self.assertEqual(len(res[0]), INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2) + + # from_token = to_token + + for i in range(0, INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2): + self.get_success( + self.msg_handler.send_device_message( + requester, + "m.room_key_request", + { + myuser: { + "device2": { + "action": "request", + "request_id": f"request_id_{i}", + "requesting_device_id": "device", + } + } + }, + ) + ) + + to_token = self.event_sources.get_current_token() + + res = self.get_success( + self.store.get_messages_for_device( + myuser, + "device2", + from_token.to_device_key, + to_token.to_device_key, + INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 5, + ) + ) + self.assertEqual(len(res[0]), INBOX_SIZE_LIMIT_FOR_KEY_REQUEST) + class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: |