summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/devicemessage.py33
1 files changed, 27 insertions, 6 deletions
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index c5d631de07..580b941595 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict
 
-from synapse.api.constants import EduTypes
+from synapse.api.constants import ToDeviceEventTypes
 from synapse.api.errors import SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.logging.context import run_in_background
@@ -79,6 +79,8 @@ class DeviceMessageHandler:
                 ReplicationUserDevicesResyncRestServlet.make_client(hs)
             )
 
+        # a rate limiter for room key requests.  The keys are
+        # (sending_user_id, sending_device_id).
         self._ratelimiter = Ratelimiter(
             store=self.store,
             clock=hs.get_clock(),
@@ -100,12 +102,25 @@ class DeviceMessageHandler:
         for user_id, by_device in content["messages"].items():
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
-                logger.warning("Request for keys for non-local user %s", user_id)
+                logger.warning("To-device message to non-local user %s", user_id)
                 raise SynapseError(400, "Not a user here")
 
             if not by_device:
                 continue
 
+            # Ratelimit key requests by the sending user.
+            if message_type == ToDeviceEventTypes.RoomKeyRequest:
+                allowed, _ = await self._ratelimiter.can_do_action(
+                    None, (sender_user_id, None)
+                )
+                if not allowed:
+                    logger.info(
+                        "Dropping room_key_request from %s to %s due to rate limit",
+                        sender_user_id,
+                        user_id,
+                    )
+                    continue
+
             messages_by_device = {
                 device_id: {
                     "content": message_content,
@@ -192,13 +207,19 @@ class DeviceMessageHandler:
         for user_id, by_device in messages.items():
             # Ratelimit local cross-user key requests by the sending device.
             if (
-                message_type == EduTypes.RoomKeyRequest
+                message_type == ToDeviceEventTypes.RoomKeyRequest
                 and user_id != sender_user_id
-                and await self._ratelimiter.can_do_action(
+            ):
+                allowed, _ = await self._ratelimiter.can_do_action(
                     requester, (sender_user_id, requester.device_id)
                 )
-            ):
-                continue
+                if not allowed:
+                    logger.info(
+                        "Dropping room_key_request from %s to %s due to rate limit",
+                        sender_user_id,
+                        user_id,
+                    )
+                    continue
 
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):