summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-07-19 12:45:17 +0100
committerGitHub <noreply@github.com>2022-07-19 11:45:17 +0000
commitb9778673587941277e15b067ad39cdf084f7dde5 (patch)
tree1130f92b1869a63305aa4ceb1a12d540432f85fd /synapse
parentSafe async event cache (#13308) (diff)
downloadsynapse-b9778673587941277e15b067ad39cdf084f7dde5.tar.xz
Rate limit joins per-room (#13276)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/ratelimiting.py7
-rw-r--r--synapse/federation/federation_server.py16
-rw-r--r--synapse/handlers/federation_event.py4
-rw-r--r--synapse/handlers/message.py11
-rw-r--r--synapse/handlers/room_member.py37
-rw-r--r--synapse/replication/tcp/client.py17
-rw-r--r--synapse/replication/tcp/streams/events.py1
-rw-r--r--synapse/storage/databases/main/events_worker.py22
8 files changed, 106 insertions, 9 deletions
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4fc1784efe..5a91917b4a 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -112,6 +112,13 @@ class RatelimitConfig(Config):
             defaults={"per_second": 0.01, "burst_count": 10},
         )
 
+        # Track the rate of joins to a given room. If there are too many, temporarily
+        # prevent local joins and remote joins via this server.
+        self.rc_joins_per_room = RateLimitConfig(
+            config.get("rc_joins_per_room", {}),
+            defaults={"per_second": 1, "burst_count": 10},
+        )
+
         # Ratelimit cross-user key requests:
         # * For local requests this is keyed by the sending device.
         # * For requests received over federation this is keyed by the origin.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5dfdc86740..ae550d3f4d 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -118,6 +118,7 @@ class FederationServer(FederationBase):
         self._federation_event_handler = hs.get_federation_event_handler()
         self.state = hs.get_state_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
+        self._room_member_handler = hs.get_room_member_handler()
 
         self._state_storage_controller = hs.get_storage_controllers().state
 
@@ -621,6 +622,15 @@ class FederationServer(FederationBase):
             )
             raise IncompatibleRoomVersionError(room_version=room_version)
 
+        # Refuse the request if that room has seen too many joins recently.
+        # This is in addition to the HS-level rate limiting applied by
+        # BaseFederationServlet.
+        # type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?)
+        await self._room_member_handler._join_rate_per_room_limiter.ratelimit(  # type: ignore[has-type]
+            requester=None,
+            key=room_id,
+            update=False,
+        )
         pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
         return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
 
@@ -655,6 +665,12 @@ class FederationServer(FederationBase):
         room_id: str,
         caller_supports_partial_state: bool = False,
     ) -> Dict[str, Any]:
+        await self._room_member_handler._join_rate_per_room_limiter.ratelimit(  # type: ignore[has-type]
+            requester=None,
+            key=room_id,
+            update=False,
+        )
+
         event, context = await self._on_send_membership_event(
             origin, content, Membership.JOIN, room_id
         )
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index b1dab57447..766d9849f5 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1980,6 +1980,10 @@ class FederationEventHandler:
             event, event_pos, max_stream_token, extra_users=extra_users
         )
 
+        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+            # TODO retrieve the previous state, and exclude join -> join transitions
+            self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
     def _sanity_check_event(self, ev: EventBase) -> None:
         """
         Do some early sanity checks of a received event
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 85abe71ea8..bd7baef051 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -463,6 +463,7 @@ class EventCreationHandler:
         )
         self._events_shard_config = self.config.worker.events_shard_config
         self._instance_name = hs.get_instance_name()
+        self._notifier = hs.get_notifier()
 
         self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
 
@@ -1550,6 +1551,16 @@ class EventCreationHandler:
                 requester, is_admin_redaction=is_admin_redaction
             )
 
+        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+            (
+                current_membership,
+                _,
+            ) = await self.store.get_local_current_membership_for_user_in_room(
+                event.state_key, event.room_id
+            )
+            if current_membership != Membership.JOIN:
+                self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
         await self._maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index a5b9ac904e..30b4cb23df 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -94,12 +94,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
         )
+        # Tracks joins from local users to rooms this server isn't a member of.
+        # I.e. joins this server makes by requesting /make_join /send_join from
+        # another server.
         self._join_rate_limiter_remote = Ratelimiter(
             store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
+        # TODO: find a better place to keep this Ratelimiter.
+        #   It needs to be
+        #    - written to by event persistence code
+        #    - written to by something which can snoop on replication streams
+        #    - read by the RoomMemberHandler to rate limit joins from local users
+        #    - read by the FederationServer to rate limit make_joins and send_joins from
+        #      other homeservers
+        #   I wonder if a homeserver-wide collection of rate limiters might be cleaner?
+        self._join_rate_per_room_limiter = Ratelimiter(
+            store=self.store,
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
+            burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
+        )
 
         # Ratelimiter for invites, keyed by room (across all issuers, all
         # recipients).
@@ -136,6 +153,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         )
 
         self.request_ratelimiter = hs.get_request_ratelimiter()
+        hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
+
+    def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
+        """Notify the rate limiter that a room join has occurred.
+
+        Use this to inform the RoomMemberHandler about joins that have either
+        - taken place on another homeserver, or
+        - on another worker in this homeserver.
+        Joins actioned by this worker should use the usual `ratelimit` method, which
+        checks the limit and increments the counter in one go.
+        """
+        self._join_rate_per_room_limiter.record_action(requester=None, key=room_id)
 
     @abc.abstractmethod
     async def _remote_join(
@@ -396,6 +425,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # up blocking profile updates.
             if newly_joined and ratelimit:
                 await self._join_rate_limiter_local.ratelimit(requester)
+                await self._join_rate_per_room_limiter.ratelimit(
+                    requester, key=room_id, update=False
+                )
 
         result_event = await self.event_creation_handler.handle_new_client_event(
             requester,
@@ -867,6 +899,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     await self._join_rate_limiter_remote.ratelimit(
                         requester,
                     )
+                    await self._join_rate_per_room_limiter.ratelimit(
+                        requester,
+                        key=room_id,
+                        update=False,
+                    )
 
                 inviter = await self._get_inviter(target.to_string(), room_id)
                 if inviter and not self.hs.is_mine(inviter):
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2f59245058..e4f2201c92 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector
 from twisted.internet.protocol import ReconnectingClientFactory
 from twisted.python.failure import Failure
 
-from synapse.api.constants import EventTypes, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
 from synapse.federation import send_queue
 from synapse.federation.sender import FederationSender
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
@@ -219,6 +219,21 @@ class ReplicationDataHandler:
                     membership=row.data.membership,
                 )
 
+                # If this event is a join, make a note of it so we have an accurate
+                # cross-worker room rate limit.
+                # TODO: Erik said we should exclude rows that came from ex_outliers
+                #  here, but I don't see how we can determine that. I guess we could
+                #  add a flag to row.data?
+                if (
+                    row.data.type == EventTypes.Member
+                    and row.data.membership == Membership.JOIN
+                    and not row.data.outlier
+                ):
+                    # TODO retrieve the previous state, and exclude join -> join transitions
+                    self.notifier.notify_user_joined_room(
+                        row.data.event_id, row.data.room_id
+                    )
+
         await self._presence_handler.process_replication_rows(
             stream_name, instance_name, token, rows
         )
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 26f4fa7cfd..14b6705862 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -98,6 +98,7 @@ class EventsStreamEventRow(BaseEventsStreamRow):
     relates_to: Optional[str]
     membership: Optional[str]
     rejected: bool
+    outlier: bool
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4435373146..5914a35420 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1490,7 +1490,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
         """Returns new events, for the Events replication stream
 
         Args:
@@ -1506,10 +1506,11 @@ class EventsWorkerStore(SQLBaseStore):
 
         def get_all_new_forward_event_rows(
             txn: LoggingTransaction,
-        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+                " e.outlier"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events AS se USING (event_id)"
@@ -1523,7 +1524,8 @@ class EventsWorkerStore(SQLBaseStore):
             )
             txn.execute(sql, (last_id, current_id, instance_name, limit))
             return cast(
-                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+                List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+                txn.fetchall(),
             )
 
         return await self.db_pool.runInteraction(
@@ -1532,7 +1534,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_ex_outlier_stream_rows(
         self, instance_name: str, last_id: int, current_id: int
-    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
@@ -1547,11 +1549,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         def get_ex_outlier_stream_rows_txn(
             txn: LoggingTransaction,
-        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+                " e.outlier"
                 " FROM events AS e"
+                # NB: the next line (inner join) is what makes this query different from
+                # get_all_new_forward_event_rows.
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events AS se USING (event_id)"
@@ -1566,7 +1571,8 @@ class EventsWorkerStore(SQLBaseStore):
 
             txn.execute(sql, (last_id, current_id, instance_name))
             return cast(
-                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+                List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+                txn.fetchall(),
             )
 
         return await self.db_pool.runInteraction(