diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 39e57a4503..4634f4df9d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -16,13 +16,14 @@
import abc
import logging
from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
from synapse import types
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
+from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase
@@ -36,6 +37,10 @@ from synapse.util.distributor import user_joined_room, user_left_room
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
logger = logging.getLogger(__name__)
@@ -47,7 +52,7 @@ class RoomMemberHandler(object):
__metaclass__ = abc.ABCMeta
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -78,6 +83,17 @@ class RoomMemberHandler(object):
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
+ self._join_rate_limiter_local = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
+ burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
+ )
+ self._join_rate_limiter_remote = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
+ burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
+ )
+
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
# it doesn't store state.
@@ -196,7 +212,7 @@ class RoomMemberHandler(object):
return duplicate.event_id, stream_id
stream_id = await self.event_creation_handler.handle_new_client_event(
- requester, event, context, extra_users=[target], ratelimit=ratelimit
+ requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
prev_state_ids = await context.get_prev_state_ids()
@@ -461,7 +477,28 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
- if not is_host_in_room:
+ if is_host_in_room:
+ time_now_s = self.clock.time()
+ allowed, time_allowed = self._join_rate_limiter_local.can_do_action(
+ requester.user.to_string(),
+ )
+
+ if not allowed:
+ raise LimitExceededError(
+ retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ )
+
+ else:
+ time_now_s = self.clock.time()
+ allowed, time_allowed = self._join_rate_limiter_remote.can_do_action(
+ requester.user.to_string(),
+ )
+
+ if not allowed:
+ raise LimitExceededError(
+ retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ )
+
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@@ -987,7 +1024,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
check_complexity = self.hs.config.limit_remote_rooms.enabled
if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join:
- check_complexity = not await self.hs.auth.is_server_admin(user)
+ check_complexity = not await self.auth.is_server_admin(user)
if check_complexity:
# Fetch the room complexity
|