diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1eb4501c99..b13d5bd48e 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -816,7 +816,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
- inboxes_size = {}
+ inbox_sizes = {}
if size_limit:
sql = """
SELECT device_id, COUNT(*) FROM device_inbox
@@ -825,7 +825,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (user_id,))
for r in txn:
- inboxes_size[r[0]] = r[1]
+ inbox_sizes[r[0]] = r[1]
messages_json_for_user = {}
devices = list(messages_by_device.keys())
@@ -842,10 +842,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
message_json = json_encoder.encode(messages_by_device["*"])
for device_id in devices:
- if (
- size_limit is None
- or inboxes_size.get(device_id, 0) <= size_limit
- ):
+ if size_limit is None or inbox_sizes.get(device_id, 0) < size_limit:
# Add the message for all devices for this user on this
# server.
messages_json_for_user[device_id] = message_json
@@ -881,10 +878,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
message_json = json_encoder.encode(msg)
- if (
- size_limit is None
- or inboxes_size.get(device_id, 0) <= size_limit
- ):
+ if size_limit is None or inbox_sizes.get(device_id, 0) < size_limit:
messages_json_for_user[device_id] = message_json
if messages_json_for_user:
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index ee48f9e546..821c31c7c0 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -23,20 +23,28 @@ from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import NotFoundError, SynapseError
from synapse.appservice import ApplicationService
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
+from synapse.handlers.devicemessage import INBOX_SIZE_LIMIT_FOR_KEY_REQUEST
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
+import synapse
+
user1 = "@boris:aaa"
user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ ]
+
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice_api = mock.Mock()
hs = self.setup_test_homeserver(
@@ -47,6 +55,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
+ self.msg_handler = hs.get_device_message_handler()
+ self.event_sources = hs.get_event_sources()
self.store = hs.get_datastores().main
return hs
@@ -398,6 +408,79 @@ class DeviceTestCase(unittest.HomeserverTestCase):
],
)
+ def test_room_key_request_limit(self) -> None:
+ store = self.hs.get_datastores().main
+
+ myuser = self.register_user("myuser", "pass")
+ self.login("myuser", "pass", "device")
+ self.login("myuser", "pass", "device2")
+
+ requester = requester = create_requester(myuser)
+
+ from_token = self.event_sources.get_current_token()
+
+ # for i in range(0, INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2):
+ # self.get_success(
+ # self.msg_handler.send_device_message(
+ # requester,
+ # "m.room_key",
+ # {
+ # myuser2: {
+ # "device": {
+ # "algorithm": "m.megolm.v1.aes-sha2",
+ # "room_id": "!Cuyf34gef24t:localhost",
+ # "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ",
+ # "session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY..."
+
+ # }
+ # }
+ # },
+ # )
+ # )
+
+ # to_token = self.event_sources.get_current_token()
+
+ # res = self.get_success(self.store.get_messages_for_device(
+ # myuser2,
+ # "device",
+ # from_token.to_device_key,
+ # to_token.to_device_key,
+ # INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 5,
+ # ))
+ # self.assertEqual(len(res[0]), INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2)
+
+ # from_token = to_token
+
+ for i in range(0, INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 2):
+ self.get_success(
+ self.msg_handler.send_device_message(
+ requester,
+ "m.room_key_request",
+ {
+ myuser: {
+ "device2": {
+ "action": "request",
+ "request_id": f"request_id_{i}",
+ "requesting_device_id": "device",
+ }
+ }
+ },
+ )
+ )
+
+ to_token = self.event_sources.get_current_token()
+
+ res = self.get_success(
+ self.store.get_messages_for_device(
+ myuser,
+ "device2",
+ from_token.to_device_key,
+ to_token.to_device_key,
+ INBOX_SIZE_LIMIT_FOR_KEY_REQUEST * 5,
+ )
+ )
+ self.assertEqual(len(res[0]), INBOX_SIZE_LIMIT_FOR_KEY_REQUEST)
+
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|