summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-05-11 11:02:56 +0100
committerGitHub <noreply@github.com>2021-05-11 11:02:56 +0100
commit7967b36efe6a033f46cd882d0b31a8c3eb18631c (patch)
tree2949141c3cb6180ff1f3ce6ed5bdb54b4b575e31 /synapse
parentImprove performance of backfilling in large rooms. (#9935) (diff)
downloadsynapse-7967b36efe6a033f46cd882d0b31a8c3eb18631c.tar.xz
Fix `m.room_key_request` to-device messages (#9961)
fixes #9960 
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/constants.py5
-rw-r--r--synapse/federation/federation_server.py19
-rw-r--r--synapse/handlers/devicemessage.py33
3 files changed, 31 insertions, 26 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index ab628b2be7..3940da5c88 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -116,9 +116,12 @@ class EventTypes:
     MSC1772_SPACE_PARENT = "org.matrix.msc1772.space.parent"
 
 
+class ToDeviceEventTypes:
+    RoomKeyRequest = "m.room_key_request"
+
+
 class EduTypes:
     Presence = "m.presence"
-    RoomKeyRequest = "m.room_key_request"
 
 
 class RejectedReason:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index b729a69203..ace30aa450 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -44,7 +44,6 @@ from synapse.api.errors import (
     SynapseError,
     UnsupportedRoomVersionError,
 )
-from synapse.api.ratelimiting import Ratelimiter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
@@ -865,14 +864,6 @@ class FederationHandlerRegistry:
         # EDU received.
         self._edu_type_to_instance = {}  # type: Dict[str, List[str]]
 
-        # A rate limiter for incoming room key requests per origin.
-        self._room_key_request_rate_limiter = Ratelimiter(
-            store=hs.get_datastore(),
-            clock=self.clock,
-            rate_hz=self.config.rc_key_requests.per_second,
-            burst_count=self.config.rc_key_requests.burst_count,
-        )
-
     def register_edu_handler(
         self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
     ) -> None:
@@ -926,16 +917,6 @@ class FederationHandlerRegistry:
         if not self.config.use_presence and edu_type == EduTypes.Presence:
             return
 
-        # If the incoming room key requests from a particular origin are over
-        # the limit, drop them.
-        if (
-            edu_type == EduTypes.RoomKeyRequest
-            and not await self._room_key_request_rate_limiter.can_do_action(
-                None, origin
-            )
-        ):
-            return
-
         # Check if we have a handler on this instance
         handler = self.edu_handlers.get(edu_type)
         if handler:
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index c5d631de07..580b941595 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict
 
-from synapse.api.constants import EduTypes
+from synapse.api.constants import ToDeviceEventTypes
 from synapse.api.errors import SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.logging.context import run_in_background
@@ -79,6 +79,8 @@ class DeviceMessageHandler:
                 ReplicationUserDevicesResyncRestServlet.make_client(hs)
             )
 
+        # a rate limiter for room key requests.  The keys are
+        # (sending_user_id, sending_device_id).
         self._ratelimiter = Ratelimiter(
             store=self.store,
             clock=hs.get_clock(),
@@ -100,12 +102,25 @@ class DeviceMessageHandler:
         for user_id, by_device in content["messages"].items():
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
-                logger.warning("Request for keys for non-local user %s", user_id)
+                logger.warning("To-device message to non-local user %s", user_id)
                 raise SynapseError(400, "Not a user here")
 
             if not by_device:
                 continue
 
+            # Ratelimit key requests by the sending user.
+            if message_type == ToDeviceEventTypes.RoomKeyRequest:
+                allowed, _ = await self._ratelimiter.can_do_action(
+                    None, (sender_user_id, None)
+                )
+                if not allowed:
+                    logger.info(
+                        "Dropping room_key_request from %s to %s due to rate limit",
+                        sender_user_id,
+                        user_id,
+                    )
+                    continue
+
             messages_by_device = {
                 device_id: {
                     "content": message_content,
@@ -192,13 +207,19 @@ class DeviceMessageHandler:
         for user_id, by_device in messages.items():
             # Ratelimit local cross-user key requests by the sending device.
             if (
-                message_type == EduTypes.RoomKeyRequest
+                message_type == ToDeviceEventTypes.RoomKeyRequest
                 and user_id != sender_user_id
-                and await self._ratelimiter.can_do_action(
+            ):
+                allowed, _ = await self._ratelimiter.can_do_action(
                     requester, (sender_user_id, requester.device_id)
                 )
-            ):
-                continue
+                if not allowed:
+                    logger.info(
+                        "Dropping room_key_request from %s to %s due to rate limit",
+                        sender_user_id,
+                        user_id,
+                    )
+                    continue
 
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):