diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 1aa7d803b5..7db4f48965 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -16,7 +16,9 @@
import logging
from typing import TYPE_CHECKING, Any, Dict
+from synapse.api.constants import EduTypes
from synapse.api.errors import SynapseError
+from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
get_active_span_text_map,
@@ -25,7 +27,7 @@ from synapse.logging.opentracing import (
start_active_span,
)
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
@@ -78,6 +80,12 @@ class DeviceMessageHandler:
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
+ self._ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=hs.config.rc_key_requests.per_second,
+ burst_count=hs.config.rc_key_requests.burst_count,
+ )
+
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
local_messages = {}
sender_user_id = content["sender"]
@@ -168,15 +176,27 @@ class DeviceMessageHandler:
async def send_device_message(
self,
- sender_user_id: str,
+ requester: Requester,
message_type: str,
messages: Dict[str, Dict[str, JsonDict]],
) -> None:
+ sender_user_id = requester.user.to_string()
+
set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id)
local_messages = {}
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
for user_id, by_device in messages.items():
+ # Ratelimit local cross-user key requests by the sending device.
+ if (
+ message_type == EduTypes.RoomKeyRequest
+ and user_id != sender_user_id
+ and self._ratelimiter.can_do_action(
+ (sender_user_id, requester.device_id)
+ )
+ ):
+ continue
+
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
messages_by_device = {
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 3e23f82cf7..f46cab7325 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -17,7 +17,7 @@ import logging
import random
from typing import TYPE_CHECKING, Iterable, List, Optional
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
@@ -113,7 +113,7 @@ class EventStreamHandler(BaseHandler):
states = await presence_handler.get_states(users)
to_add.extend(
{
- "type": EventTypes.Presence,
+ "type": EduTypes.Presence,
"content": format_user_presence_state(state, time_now),
}
for state in states
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 78c3e5a10b..71a5076672 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
from twisted.internet import defer
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
@@ -412,7 +412,7 @@ class InitialSyncHandler(BaseHandler):
return [
{
- "type": EventTypes.Presence,
+ "type": EduTypes.Presence,
"content": format_user_presence_state(s, time_now),
}
for s in states
|