summary refs log tree commit diff
diff options
context:
space:
mode:
authorMathieu Velten <mathieuv@matrix.org>2023-06-20 14:22:16 +0200
committerMathieu Velten <matmaul@gmail.com>2023-06-20 21:37:49 +0200
commit5047c01d3f1ace45edc124e4181f4c86df8cdb88 (patch)
tree0c6185ee4c3879d9caa41e105208cb174c8837b5
parentFix admin api documentation typo (#15805) (diff)
downloadsynapse-5047c01d3f1ace45edc124e4181f4c86df8cdb88.tar.xz
Ignore key requests if the device inbox is already big
-rw-r--r--changelog.d/15808.misc1
-rw-r--r--synapse/handlers/devicemessage.py13
-rw-r--r--synapse/storage/databases/main/deviceinbox.py42
3 files changed, 47 insertions, 9 deletions
diff --git a/changelog.d/15808.misc b/changelog.d/15808.misc
new file mode 100644
index 0000000000..a0829c50fe
--- /dev/null
+++ b/changelog.d/15808.misc
@@ -0,0 +1 @@
+Ignore key request if the device inbox is already big.
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 3caf9b31cc..d5fed8cadf 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -39,6 +39,9 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+INBOX_SIZE_LIMIT_FOR_KEY_REQUEST = 100
+
+
 class DeviceMessageHandler:
     def __init__(self, hs: "HomeServer"):
         """
@@ -166,7 +169,7 @@ class DeviceMessageHandler:
         found marks the remote cache for the user as stale.
         """
 
-        if message_type != "m.room_key_request":
+        if message_type != ToDeviceEventTypes.RoomKeyRequest:
             return
 
         # Get the sending device IDs
@@ -286,10 +289,16 @@ class DeviceMessageHandler:
                 "org.matrix.opentracing_context": json_encoder.encode(context),
             }
 
+        device_inbox_size_limit = None
+        if message_type == ToDeviceEventTypes.RoomKeyRequest and self.is_mine(
+            UserID.from_string(user_id)
+        ):
+            device_inbox_size_limit = INBOX_SIZE_LIMIT_FOR_KEY_REQUEST
+
         # Add messages to the database.
         # Retrieve the stream id of the last-processed to-device message.
         last_stream_id = await self.store.add_messages_to_device_inbox(
-            local_messages, remote_edu_contents
+            local_messages, remote_edu_contents, device_inbox_size_limit
         )
 
         # Notify listeners that there are new to-device messages to process,
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index b471fcb064..1eb4501c99 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -650,6 +650,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         self,
         local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
         remote_messages_by_destination: Dict[str, JsonDict],
+        size_limit: Optional[int] = None,
     ) -> int:
         """Used to send messages from this server.
 
@@ -666,11 +667,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         assert self._can_write_to_device
 
         def add_messages_txn(
-            txn: LoggingTransaction, now_ms: int, stream_id: int
+            txn: LoggingTransaction,
+            now_ms: int,
+            stream_id: int,
+            size_limit: Optional[int],
         ) -> None:
             # Add the local messages directly to the local inbox.
             self._add_messages_to_local_device_inbox_txn(
-                txn, stream_id, local_messages_by_user_then_device
+                txn, stream_id, local_messages_by_user_then_device, size_limit
             )
 
             # Add the remote messages to the federation outbox.
@@ -731,7 +735,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self._clock.time_msec()
             await self.db_pool.runInteraction(
-                "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
+                "add_messages_to_device_inbox",
+                add_messages_txn,
+                now_ms,
+                stream_id,
+                size_limit,
             )
             for user_id in local_messages_by_user_then_device.keys():
                 self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
@@ -802,11 +810,23 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         txn: LoggingTransaction,
         stream_id: int,
         messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
+        size_limit: Optional[int] = None,
     ) -> None:
         assert self._can_write_to_device
 
         local_by_user_then_device = {}
         for user_id, messages_by_device in messages_by_user_then_device.items():
+            inboxes_size = {}
+            if size_limit:
+                sql = """
+                    SELECT device_id, COUNT(*) FROM device_inbox
+                    WHERE user_id = ?
+                    GROUP BY device_id
+                """
+                txn.execute(sql, (user_id,))
+                for r in txn:
+                    inboxes_size[r[0]] = r[1]
+
             messages_json_for_user = {}
             devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
@@ -822,9 +842,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
                 message_json = json_encoder.encode(messages_by_device["*"])
                 for device_id in devices:
-                    # Add the message for all devices for this user on this
-                    # server.
-                    messages_json_for_user[device_id] = message_json
+                    if (
+                        size_limit is None
+                        or inboxes_size.get(device_id, 0) <= size_limit
+                    ):
+                        # Add the message for all devices for this user on this
+                        # server.
+                        messages_json_for_user[device_id] = message_json
             else:
                 if not devices:
                     continue
@@ -857,7 +881,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                         )
                         message_json = json_encoder.encode(msg)
 
-                    messages_json_for_user[device_id] = message_json
+                    if (
+                        size_limit is None
+                        or inboxes_size.get(device_id, 0) <= size_limit
+                    ):
+                        messages_json_for_user[device_id] = message_json
 
             if messages_json_for_user:
                 local_by_user_then_device[user_id] = messages_json_for_user