diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4fc1784efe..5a91917b4a 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -112,6 +112,13 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.01, "burst_count": 10},
)
+ # Track the rate of joins to a given room. If there are too many, temporarily
+ # prevent local joins and remote joins via this server.
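+ # With the defaults below, a room can absorb a burst of 10 joins, after which
+ # joins are throttled to roughly one per second.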
+ self.rc_joins_per_room = RateLimitConfig(
+ config.get("rc_joins_per_room", {}),
+ defaults={"per_second": 1, "burst_count": 10},
+ )
+
# Ratelimit cross-user key requests:
# * For local requests this is keyed by the sending device.
# * For requests received over federation this is keyed by the origin.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5dfdc86740..ae550d3f4d 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -118,6 +118,7 @@ class FederationServer(FederationBase):
self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
+ self._room_member_handler = hs.get_room_member_handler()
self._state_storage_controller = hs.get_storage_controllers().state
@@ -621,6 +622,15 @@ class FederationServer(FederationBase):
)
raise IncompatibleRoomVersionError(room_version=room_version)
+ # Refuse the request if that room has seen too many joins recently.
+ # This is in addition to the HS-level rate limiting applied by
+ # BaseFederationServlet.
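+ # This is a check only (update=False): the counter is incremented when a join
+ # event is actually persisted, via RoomMemberHandler._on_user_joined_room.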
+ # type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?)
+ await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
+ requester=None,
+ key=room_id,
+ update=False,
+ )
pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
@@ -655,6 +665,12 @@ class FederationServer(FederationBase):
room_id: str,
caller_supports_partial_state: bool = False,
) -> Dict[str, Any]:
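+ # As in on_make_join_request: refuse the request if the room has seen too many
+ # joins recently. Check only; the counter is incremented when the join is persisted.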
+ await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
+ requester=None,
+ key=room_id,
+ update=False,
+ )
+
event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id
)
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index b1dab57447..766d9849f5 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1980,6 +1980,10 @@ class FederationEventHandler:
event, event_pos, max_stream_token, extra_users=extra_users
)
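+ # Record this join against the per-room join rate limiter, via the notifier
+ # callback registered by RoomMemberHandler.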
+ if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+ # TODO retrieve the previous state, and exclude join -> join transitions
+ self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 85abe71ea8..bd7baef051 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -463,6 +463,7 @@ class EventCreationHandler:
)
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
+ self._notifier = hs.get_notifier()
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
@@ -1550,6 +1551,16 @@ class EventCreationHandler:
requester, is_admin_redaction=is_admin_redaction
)
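+ # A join event may just be an update to an existing join (e.g. a display name
+ # change). Only notify of genuine joins, so that profile updates don't count
+ # against the per-room join rate limiter.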
+ if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+ (
+ current_membership,
+ _,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ event.state_key, event.room_id
+ )
+ if current_membership != Membership.JOIN:
+ self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
await self._maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index a5b9ac904e..30b4cb23df 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -94,12 +94,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
)
+ # Tracks joins from local users to rooms this server isn't a member of.
+ # I.e. joins this server makes by requesting /make_join and /send_join from
+ # another server.
self._join_rate_limiter_remote = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
+ # TODO: find a better place to keep this Ratelimiter.
+ # It needs to be
+ # - written to by event persistence code
+ # - written to by something which can snoop on replication streams
+ # - read by the RoomMemberHandler to rate limit joins from local users
+ # - read by the FederationServer to rate limit make_joins and send_joins from
+ # other homeservers
+ # I wonder if a homeserver-wide collection of rate limiters might be cleaner?
+ self._join_rate_per_room_limiter = Ratelimiter(
+ store=self.store,
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
+ burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
+ )
# Ratelimiter for invites, keyed by room (across all issuers, all
# recipients).
@@ -136,6 +153,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
self.request_ratelimiter = hs.get_request_ratelimiter()
+ hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
+
+ def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
+ """Notify the rate limiter that a room join has occurred.
+
+ Use this to inform the RoomMemberHandler about joins that have either
+ - taken place on another homeserver, or
+ - been handled by another worker in this homeserver.
+ Joins actioned by this worker should use the usual `ratelimit` method, which
+ checks the limit and increments the counter in one go.
+ """
+ self._join_rate_per_room_limiter.record_action(requester=None, key=room_id)
@abc.abstractmethod
async def _remote_join(
@@ -396,6 +425,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# up blocking profile updates.
if newly_joined and ratelimit:
await self._join_rate_limiter_local.ratelimit(requester)
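+ # Only check the per-room limit here (update=False); the counter is
+ # incremented by _on_user_joined_room once the join event is persisted.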
+ await self._join_rate_per_room_limiter.ratelimit(
+ requester, key=room_id, update=False
+ )
result_event = await self.event_creation_handler.handle_new_client_event(
requester,
@@ -867,6 +899,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
await self._join_rate_limiter_remote.ratelimit(
requester,
)
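+ # Likewise check (but don't consume) the per-room join limit before
+ # attempting a remote join.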
+ await self._join_rate_per_room_limiter.ratelimit(
+ requester,
+ key=room_id,
+ update=False,
+ )
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2f59245058..e4f2201c92 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector
from twisted.internet.protocol import ReconnectingClientFactory
from twisted.python.failure import Failure
-from synapse.api.constants import EventTypes, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.federation import send_queue
from synapse.federation.sender import FederationSender
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
@@ -219,6 +219,21 @@ class ReplicationDataHandler:
membership=row.data.membership,
)
+ # If this event is a join, make a note of it so we have an accurate
+ # cross-worker room rate limit.
+ # TODO: Erik said we should exclude rows that came from ex_outliers
+ # here, but I don't see how we can determine that. I guess we could
+ # add a flag to row.data?
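+ # (notify_user_joined_room feeds RoomMemberHandler._on_user_joined_room,
+ # which records the join against the per-room join rate limiter.)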
+ if (
+ row.data.type == EventTypes.Member
+ and row.data.membership == Membership.JOIN
+ and not row.data.outlier
+ ):
+ # TODO retrieve the previous state, and exclude join -> join transitions
+ self.notifier.notify_user_joined_room(
+ row.data.event_id, row.data.room_id
+ )
+
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows
)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 26f4fa7cfd..14b6705862 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -98,6 +98,7 @@ class EventsStreamEventRow(BaseEventsStreamRow):
relates_to: Optional[str]
membership: Optional[str]
rejected: bool
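+ # True if the event is marked as an outlier in the database.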
+ outlier: bool
@attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4435373146..5914a35420 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1490,7 +1490,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1506,10 +1506,11 @@ class EventsWorkerStore(SQLBaseStore):
def get_all_new_forward_event_rows(
txn: LoggingTransaction,
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+ " e.outlier"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
@@ -1523,7 +1524,8 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
return cast(
- List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+ txn.fetchall(),
)
return await self.db_pool.runInteraction(
@@ -1532,7 +1534,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1547,11 +1549,14 @@ class EventsWorkerStore(SQLBaseStore):
def get_ex_outlier_stream_rows_txn(
txn: LoggingTransaction,
- ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+ ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+ " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+ " e.outlier"
" FROM events AS e"
+ # NB: the next line (inner join) is what makes this query different from
+ # get_all_new_forward_event_rows.
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
@@ -1566,7 +1571,8 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, instance_name))
return cast(
- List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+ List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+ txn.fetchall(),
)
return await self.db_pool.runInteraction(