summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/devicemessage.py24
-rw-r--r--synapse/handlers/events.py4
-rw-r--r--synapse/handlers/initial_sync.py4
3 files changed, 26 insertions, 6 deletions
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 1aa7d803b5..7db4f48965 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -16,7 +16,9 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict
 
+from synapse.api.constants import EduTypes
 from synapse.api.errors import SynapseError
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
     get_active_span_text_map,
@@ -25,7 +27,7 @@ from synapse.logging.opentracing import (
     start_active_span,
 )
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.stringutils import random_string
 
@@ -78,6 +80,12 @@ class DeviceMessageHandler:
                 ReplicationUserDevicesResyncRestServlet.make_client(hs)
             )
 
+        self._ratelimiter = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.rc_key_requests.per_second,
+            burst_count=hs.config.rc_key_requests.burst_count,
+        )
+
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
         local_messages = {}
         sender_user_id = content["sender"]
@@ -168,15 +176,27 @@ class DeviceMessageHandler:
 
     async def send_device_message(
         self,
-        sender_user_id: str,
+        requester: Requester,
         message_type: str,
         messages: Dict[str, Dict[str, JsonDict]],
     ) -> None:
+        sender_user_id = requester.user.to_string()
+
         set_tag("number_of_messages", len(messages))
         set_tag("sender", sender_user_id)
         local_messages = {}
         remote_messages = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
         for user_id, by_device in messages.items():
+            # Ratelimit local cross-user key requests by the sending device.
+            if (
+                message_type == EduTypes.RoomKeyRequest
+                and user_id != sender_user_id
+                and self._ratelimiter.can_do_action(
+                    (sender_user_id, requester.device_id)
+                )
+            ):
+                continue
+
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
                 messages_by_device = {
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 3e23f82cf7..f46cab7325 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -17,7 +17,7 @@ import logging
 import random
 from typing import TYPE_CHECKING, Iterable, List, Optional
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import AuthError, SynapseError
 from synapse.events import EventBase
 from synapse.handlers.presence import format_user_presence_state
@@ -113,7 +113,7 @@ class EventStreamHandler(BaseHandler):
                     states = await presence_handler.get_states(users)
                     to_add.extend(
                         {
-                            "type": EventTypes.Presence,
+                            "type": EduTypes.Presence,
                             "content": format_user_presence_state(state, time_now),
                         }
                         for state in states
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 78c3e5a10b..71a5076672 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
 
 from twisted.internet import defer
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.events.validator import EventValidator
 from synapse.handlers.presence import format_user_presence_state
@@ -412,7 +412,7 @@ class InitialSyncHandler(BaseHandler):
 
             return [
                 {
-                    "type": EventTypes.Presence,
+                    "type": EduTypes.Presence,
                     "content": format_user_presence_state(s, time_now),
                 }
                 for s in states