diff --git a/changelog.d/16441.misc b/changelog.d/16441.misc
new file mode 100644
index 0000000000..32264a62b2
--- /dev/null
+++ b/changelog.d/16441.misc
@@ -0,0 +1 @@
+Improve rate limiting logic.
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 90343c2306..1b50495af1 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -382,8 +382,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
and persist a new event for the new membership change.
Args:
- requester:
- target:
+ requester: User requesting the membership change, i.e. the sender of the
+ desired membership event.
+ target: Use whose membership should change, i.e. the state_key of the
+ desired membership event.
room_id:
membership:
@@ -415,7 +417,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
Returns:
Tuple of event ID and stream ordering position
"""
-
user_id = target.to_string()
if content is None:
@@ -475,21 +476,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
(EventTypes.Member, user_id), None
)
- if event.membership == Membership.JOIN:
- newly_joined = True
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(
- prev_member_event_id
- )
- newly_joined = prev_member_event.membership != Membership.JOIN
-
- # Only rate-limit if the user actually joined the room, otherwise we'll end
- # up blocking profile updates.
- if newly_joined and ratelimit:
- await self._join_rate_limiter_local.ratelimit(requester)
- await self._join_rate_per_room_limiter.ratelimit(
- requester, key=room_id, update=False
- )
with opentracing.start_active_span("handle_new_client_event"):
result_event = (
await self.event_creation_handler.handle_new_client_event(
@@ -618,6 +604,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
Raises:
ShadowBanError if a shadow-banned requester attempts to send an invite.
"""
+ if ratelimit:
+ if action == Membership.JOIN:
+ # Only rate-limit if the user isn't already joined to the room, otherwise
+ # we'll end up blocking profile updates.
+ (
+ current_membership,
+ _,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ requester.user.to_string(),
+ room_id,
+ )
+ if current_membership != Membership.JOIN:
+ await self._join_rate_limiter_local.ratelimit(requester)
+ await self._join_rate_per_room_limiter.ratelimit(
+ requester, key=room_id, update=False
+ )
+ elif action == Membership.INVITE:
+ await self.ratelimit_invite(requester, room_id, target.to_string())
+
if action == Membership.INVITE and requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
@@ -794,8 +799,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if effective_membership_state == Membership.INVITE:
target_id = target.to_string()
- if ratelimit:
- await self.ratelimit_invite(requester, room_id, target_id)
# block any attempts to invite the server notices mxid
if target_id == self._server_notices_mxid:
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 7627823d3f..aaa4f3bba0 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -1447,6 +1447,30 @@ class RoomJoinRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
+ def test_join_attempts_local_ratelimit(self) -> None:
+ """Tests that unsuccessful joins that end up being denied are rate-limited."""
+ # Create 4 rooms
+ room_ids = [
+ self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4)
+ ]
+ # Pre-emptively ban the user who will attempt to join.
+ joiner_user_id = self.register_user("joiner", "secret")
+ for room_id in room_ids:
+ self.helper.ban(room_id, self.user_id, joiner_user_id)
+
+ # Now make a new user try to join some of them.
+ # The user can make 3 requests, each of which should be denied.
+ for room_id in room_ids[0:3]:
+ self.helper.join(room_id, joiner_user_id, expect_code=HTTPStatus.FORBIDDEN)
+
+ # The fourth attempt should be rate limited.
+ self.helper.join(
+ room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS
+ )
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
def test_join_local_ratelimit_profile_change(self) -> None:
"""Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins."""
|