diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 04c44b2ccb..30b4cb23df 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -94,12 +94,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
)
+ # Tracks joins from local users to rooms this server isn't a member of.
+ # I.e. joins this server makes by requesting /make_join /send_join from
+ # another server.
self._join_rate_limiter_remote = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
+ # TODO: find a better place to keep this Ratelimiter.
+ # It needs to be
+ # - written to by event persistence code
+ # - written to by something which can snoop on replication streams
+ # - read by the RoomMemberHandler to rate limit joins from local users
+ # - read by the FederationServer to rate limit make_joins and send_joins from
+ # other homeservers
+ # I wonder if a homeserver-wide collection of rate limiters might be cleaner?
+ self._join_rate_per_room_limiter = Ratelimiter(
+ store=self.store,
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
+ burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
+ )
# Ratelimiter for invites, keyed by room (across all issuers, all
# recipients).
@@ -136,6 +153,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
self.request_ratelimiter = hs.get_request_ratelimiter()
+ hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
+
+ def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
+ """Notify the rate limiter that a room join has occurred.
+
+ Use this to inform the RoomMemberHandler about joins that have either
+ - taken place on another homeserver, or
+ - on another worker in this homeserver.
+ Joins actioned by this worker should use the usual `ratelimit` method, which
+ checks the limit and increments the counter in one go.
+ """
+ self._join_rate_per_room_limiter.record_action(requester=None, key=room_id)
@abc.abstractmethod
async def _remote_join(
@@ -285,6 +314,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
+ depth: Optional[int] = None,
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
@@ -315,6 +345,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
+ depth: Override the depth used to order the event in the DAG.
+ Should normally be set to None, which will cause the depth to be calculated
+ based on the prev_events.
txn_id:
ratelimit:
@@ -370,6 +403,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
+ depth=depth,
require_consent=require_consent,
outlier=outlier,
historical=historical,
@@ -391,6 +425,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# up blocking profile updates.
if newly_joined and ratelimit:
await self._join_rate_limiter_local.ratelimit(requester)
+ await self._join_rate_per_room_limiter.ratelimit(
+ requester, key=room_id, update=False
+ )
result_event = await self.event_creation_handler.handle_new_client_event(
requester,
@@ -466,6 +503,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
+ depth: Optional[int] = None,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
@@ -501,6 +539,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
+ depth: Override the depth used to order the event in the DAG.
+ Should normally be set to None, which will cause the depth to be calculated
+ based on the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@@ -540,6 +581,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
+ depth=depth,
)
return result
@@ -562,6 +604,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
+ depth: Optional[int] = None,
) -> Tuple[str, int]:
"""Helper for update_membership.
@@ -599,6 +642,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
prev_events are set so we need to set them ourself via this argument.
This should normally be left as None, which will cause the auth_event_ids
to be calculated based on the room state at the prev_events.
+ depth: Override the depth used to order the event in the DAG.
+ Should normally be set to None, which will cause the depth to be calculated
+ based on the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@@ -732,6 +778,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
+ depth=depth,
content=content,
require_consent=require_consent,
outlier=outlier,
@@ -740,14 +787,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
- current_state_ids = await self.state_handler.get_current_state_ids(
- room_id, latest_event_ids=latest_event_ids
+ state_before_join = await self.state_handler.compute_state_after_events(
+ room_id, latest_event_ids
)
# TODO: Refactor into dictionary of explicitly allowed transitions
# between old and new state, with specific error messages for some
# transitions and generic otherwise
- old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
+ old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
@@ -798,11 +845,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
- is_host_in_room = await self._is_host_in_room(current_state_ids)
+ is_host_in_room = await self._is_host_in_room(state_before_join)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
- guest_can_join = await self._can_guest_join(current_state_ids)
+ guest_can_join = await self._can_guest_join(state_before_join)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
@@ -840,13 +887,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Check if a remote join should be performed.
remote_join, remote_room_hosts = await self._should_perform_remote_join(
- target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
+ target.to_string(),
+ room_id,
+ remote_room_hosts,
+ content,
+ is_host_in_room,
+ state_before_join,
)
if remote_join:
if ratelimit:
await self._join_rate_limiter_remote.ratelimit(
requester,
)
+ await self._join_rate_per_room_limiter.ratelimit(
+ requester,
+ key=room_id,
+ update=False,
+ )
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
@@ -967,6 +1024,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
state_event_ids=state_event_ids,
+ depth=depth,
content=content,
require_consent=require_consent,
outlier=outlier,
@@ -979,6 +1037,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
remote_room_hosts: List[str],
content: JsonDict,
is_host_in_room: bool,
+ state_before_join: StateMap[str],
) -> Tuple[bool, List[str]]:
"""
Check whether the server should do a remote join (as opposed to a local
@@ -998,6 +1057,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
content: The content to use as the event body of the join. This may
be modified.
is_host_in_room: True if the host is in the room.
+ state_before_join: The state before the join event (i.e. the resolution of
+ the states after its parent events).
Returns:
A tuple of:
@@ -1014,20 +1075,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If the host is in the room, but not one of the authorised hosts
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
- current_state_ids = await self._storage_controllers.state.get_current_state_ids(
- room_id
- )
# If restricted join rules are not being used, a local join can always
# be used.
if not await self.event_auth_handler.has_restricted_join_rules(
- current_state_ids, room_version
+ state_before_join, room_version
):
return False, []
# If the user is invited to the room or already joined, the join
# event can always be issued locally.
- prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
+ prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
prev_member_event = None
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -1042,10 +1100,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
#
# If not, generate a new list of remote hosts based on which
# can issue invites.
- event_map = await self.store.get_events(current_state_ids.values())
+ event_map = await self.store.get_events(state_before_join.values())
current_state = {
state_key: event_map[event_id]
- for state_key, event_id in current_state_ids.items()
+ for state_key, event_id in state_before_join.items()
}
allowed_servers = get_servers_from_users(
get_users_which_can_issue_invite(current_state)
@@ -1059,7 +1117,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Ensure the member should be allowed access via membership in a room.
await self.event_auth_handler.check_restricted_join_rules(
- current_state_ids, room_version, user_id, prev_member_event
+ state_before_join, room_version, user_id, prev_member_event
)
# If this is going to be a local join, additional information must
@@ -1069,7 +1127,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
EventContentFields.AUTHORISING_USER
] = await self.event_auth_handler.get_user_which_could_invite(
room_id,
- current_state_ids,
+ state_before_join,
)
return False, []
@@ -1322,7 +1380,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
requester: Requester,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
- ) -> int:
+ prev_event_ids: Optional[List[str]] = None,
+ depth: Optional[int] = None,
+ ) -> Tuple[str, int]:
"""Invite a 3PID to a room.
Args:
@@ -1335,9 +1395,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id: The transaction ID this is part of, or None if this is not
part of a transaction.
id_access_token: The optional identity server access token.
+ depth: Override the depth used to order the event in the DAG.
+ prev_event_ids: The event IDs to use as the prev events
+ Should normally be set to None, which will cause the depth to be calculated
+ based on the prev_events.
Returns:
- The new stream ID.
+ Tuple of event ID and stream ordering position
Raises:
ShadowBanError if the requester has been shadow-banned.
@@ -1383,7 +1447,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# We don't check the invite against the spamchecker(s) here (through
# user_may_invite) because we'll do it further down the line anyway (in
# update_membership_locked).
- _, stream_id = await self.update_membership(
+ event_id, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
@@ -1402,7 +1466,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
additional_fields=spam_check[1],
)
- stream_id = await self._make_and_store_3pid_invite(
+ event, stream_id = await self._make_and_store_3pid_invite(
requester,
id_server,
medium,
@@ -1411,9 +1475,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
inviter,
txn_id=txn_id,
id_access_token=id_access_token,
+ prev_event_ids=prev_event_ids,
+ depth=depth,
)
+ event_id = event.event_id
- return stream_id
+ return event_id, stream_id
async def _make_and_store_3pid_invite(
self,
@@ -1425,7 +1492,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
user: UserID,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
- ) -> int:
+ prev_event_ids: Optional[List[str]] = None,
+ depth: Optional[int] = None,
+ ) -> Tuple[EventBase, int]:
room_state = await self._storage_controllers.state.get_current_state(
room_id,
StateFilter.from_types(
@@ -1518,8 +1587,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
},
ratelimit=False,
txn_id=txn_id,
+ prev_event_ids=prev_event_ids,
+ depth=depth,
)
- return stream_id
+ return event, stream_id
async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
# Have we just created the room, and is this about to be the very
|