diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index c5d631de07..580b941595 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict
-from synapse.api.constants import EduTypes
+from synapse.api.constants import ToDeviceEventTypes
from synapse.api.errors import SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
@@ -79,6 +79,8 @@ class DeviceMessageHandler:
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
+ # a rate limiter for room key requests. The keys are
+ # (sending_user_id, sending_device_id).
self._ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
@@ -100,12 +102,25 @@ class DeviceMessageHandler:
for user_id, by_device in content["messages"].items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
- logger.warning("Request for keys for non-local user %s", user_id)
+ logger.warning("To-device message to non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
if not by_device:
continue
+ # Ratelimit key requests by the sending user.
+ if message_type == ToDeviceEventTypes.RoomKeyRequest:
+ allowed, _ = await self._ratelimiter.can_do_action(
+ None, (sender_user_id, None)
+ )
+ if not allowed:
+ logger.info(
+ "Dropping room_key_request from %s to %s due to rate limit",
+ sender_user_id,
+ user_id,
+ )
+ continue
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -192,13 +207,19 @@ class DeviceMessageHandler:
for user_id, by_device in messages.items():
# Ratelimit local cross-user key requests by the sending device.
if (
- message_type == EduTypes.RoomKeyRequest
+ message_type == ToDeviceEventTypes.RoomKeyRequest
and user_id != sender_user_id
- and await self._ratelimiter.can_do_action(
+ ):
+ allowed, _ = await self._ratelimiter.can_do_action(
requester, (sender_user_id, requester.device_id)
)
- ):
- continue
+ if not allowed:
+ logger.info(
+ "Dropping room_key_request from %s to %s due to rate limit",
+ sender_user_id,
+ user_id,
+ )
+ continue
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
|