diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/config/ratelimiting.py | 7 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 16 | ||||
-rw-r--r-- | synapse/handlers/federation_event.py | 4 | ||||
-rw-r--r-- | synapse/handlers/message.py | 11 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 37 | ||||
-rw-r--r-- | synapse/replication/tcp/client.py | 17 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/events.py | 1 | ||||
-rw-r--r-- | synapse/storage/databases/main/events_worker.py | 22 |
8 files changed, 106 insertions, 9 deletions
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 4fc1784efe..5a91917b4a 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -112,6 +112,13 @@ class RatelimitConfig(Config): defaults={"per_second": 0.01, "burst_count": 10}, ) + # Track the rate of joins to a given room. If there are too many, temporarily + # prevent local joins and remote joins via this server. + self.rc_joins_per_room = RateLimitConfig( + config.get("rc_joins_per_room", {}), + defaults={"per_second": 1, "burst_count": 10}, + ) + # Ratelimit cross-user key requests: # * For local requests this is keyed by the sending device. # * For requests received over federation this is keyed by the origin. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 5dfdc86740..ae550d3f4d 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -118,6 +118,7 @@ class FederationServer(FederationBase): self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() + self._room_member_handler = hs.get_room_member_handler() self._state_storage_controller = hs.get_storage_controllers().state @@ -621,6 +622,15 @@ class FederationServer(FederationBase): ) raise IncompatibleRoomVersionError(room_version=room_version) + # Refuse the request if that room has seen too many joins recently. + # This is in addition to the HS-level rate limiting applied by + # BaseFederationServlet. + # type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?) + await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type] + requester=None, + key=room_id, + update=False, + ) pdu = await self.handler.on_make_join_request(origin, room_id, user_id) return {"event": pdu.get_templated_pdu_json(), "room_version": room_version} @@ -655,6 +665,12 @@ class FederationServer(FederationBase): room_id: str, caller_supports_partial_state: bool = False, ) -> Dict[str, Any]: + await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type] + requester=None, + key=room_id, + update=False, + ) + event, context = await self._on_send_membership_event( origin, content, Membership.JOIN, room_id ) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index b1dab57447..766d9849f5 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1980,6 +1980,10 @@ class FederationEventHandler: event, event_pos, max_stream_token, extra_users=extra_users ) + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + # TODO retrieve the previous state, and exclude join -> join transitions + self._notifier.notify_user_joined_room(event.event_id, event.room_id) + def _sanity_check_event(self, ev: EventBase) -> None: """ Do some early sanity checks of a received event diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 85abe71ea8..bd7baef051 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -463,6 +463,7 @@ class EventCreationHandler: ) self._events_shard_config = self.config.worker.events_shard_config self._instance_name = hs.get_instance_name() + self._notifier = hs.get_notifier() self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state @@ -1550,6 +1551,16 @@ class EventCreationHandler: requester, is_admin_redaction=is_admin_redaction ) + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + ( + current_membership, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + event.state_key, event.room_id + ) + if current_membership != Membership.JOIN: + self._notifier.notify_user_joined_room(event.event_id, event.room_id) + await self._maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a5b9ac904e..30b4cb23df 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -94,12 +94,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, ) + # Tracks joins from local users to rooms this server isn't a member of. + # I.e. joins this server makes by requesting /make_join /send_join from + # another server. self._join_rate_limiter_remote = Ratelimiter( store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) + # TODO: find a better place to keep this Ratelimiter. + # It needs to be + # - written to by event persistence code + # - written to by something which can snoop on replication streams + # - read by the RoomMemberHandler to rate limit joins from local users + # - read by the FederationServer to rate limit make_joins and send_joins from + # other homeservers + # I wonder if a homeserver-wide collection of rate limiters might be cleaner? + self._join_rate_per_room_limiter = Ratelimiter( + store=self.store, + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count, + ) # Ratelimiter for invites, keyed by room (across all issuers, all # recipients). @@ -136,6 +153,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) self.request_ratelimiter = hs.get_request_ratelimiter() + hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room) + + def _on_user_joined_room(self, event_id: str, room_id: str) -> None: + """Notify the rate limiter that a room join has occurred. + + Use this to inform the RoomMemberHandler about joins that have either + - taken place on another homeserver, or + - on another worker in this homeserver. + Joins actioned by this worker should use the usual `ratelimit` method, which + checks the limit and increments the counter in one go. + """ + self._join_rate_per_room_limiter.record_action(requester=None, key=room_id) @abc.abstractmethod async def _remote_join( @@ -396,6 +425,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # up blocking profile updates. if newly_joined and ratelimit: await self._join_rate_limiter_local.ratelimit(requester) + await self._join_rate_per_room_limiter.ratelimit( + requester, key=room_id, update=False + ) result_event = await self.event_creation_handler.handle_new_client_event( requester, @@ -867,6 +899,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): await self._join_rate_limiter_remote.ratelimit( requester, ) + await self._join_rate_per_room_limiter.ratelimit( + requester, + key=room_id, + update=False, + ) inviter = await self._get_inviter(target.to_string(), room_id) if inviter and not self.hs.is_mine(inviter): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 2f59245058..e4f2201c92 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector from twisted.internet.protocol import ReconnectingClientFactory from twisted.python.failure import Failure -from synapse.api.constants import EventTypes, ReceiptTypes +from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.federation import send_queue from synapse.federation.sender import FederationSender from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable @@ -219,6 +219,21 @@ class ReplicationDataHandler: membership=row.data.membership, ) + # If this event is a join, make a note of it so we have an accurate + # cross-worker room rate limit. + # TODO: Erik said we should exclude rows that came from ex_outliers + # here, but I don't see how we can determine that. I guess we could + # add a flag to row.data? + if ( + row.data.type == EventTypes.Member + and row.data.membership == Membership.JOIN + and not row.data.outlier + ): + # TODO retrieve the previous state, and exclude join -> join transitions + self.notifier.notify_user_joined_room( + row.data.event_id, row.data.room_id + ) + await self._presence_handler.process_replication_rows( stream_name, instance_name, token, rows ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 26f4fa7cfd..14b6705862 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -98,6 +98,7 @@ class EventsStreamEventRow(BaseEventsStreamRow): relates_to: Optional[str] membership: Optional[str] rejected: bool + outlier: bool @attr.s(slots=True, frozen=True, auto_attribs=True) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4435373146..5914a35420 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1490,7 +1490,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: """Returns new events, for the Events replication stream Args: @@ -1506,10 +1506,11 @@ class EventsWorkerStore(SQLBaseStore): def get_all_new_forward_event_rows( txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL," + " e.outlier" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events AS se USING (event_id)" @@ -1523,7 +1524,8 @@ class EventsWorkerStore(SQLBaseStore): ) txn.execute(sql, (last_id, current_id, instance_name, limit)) return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + List[Tuple[int, str, str, str, str, str, str, str, bool, bool]], + txn.fetchall(), ) return await self.db_pool.runInteraction( @@ -1532,7 +1534,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_ex_outlier_stream_rows( self, instance_name: str, last_id: int, current_id: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: """Returns de-outliered events, for the Events replication stream Args: @@ -1547,11 +1549,14 @@ class EventsWorkerStore(SQLBaseStore): def get_ex_outlier_stream_rows_txn( txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL," + " e.outlier" " FROM events AS e" + # NB: the next line (inner join) is what makes this query different from + # get_all_new_forward_event_rows. " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events AS se USING (event_id)" @@ -1566,7 +1571,8 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, instance_name)) return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + List[Tuple[int, str, str, str, str, str, str, str, bool, bool]], + txn.fetchall(), ) return await self.db_pool.runInteraction( |