summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-07-19 12:45:17 +0100
committerGitHub <noreply@github.com>2022-07-19 11:45:17 +0000
commitb9778673587941277e15b067ad39cdf084f7dde5 (patch)
tree1130f92b1869a63305aa4ceb1a12d540432f85fd /tests/handlers
parentSafe async event cache (#13308) (diff)
downloadsynapse-b9778673587941277e15b067ad39cdf084f7dde5.tar.xz
Rate limit joins per-room (#13276)
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_room_member.py290
1 files changed, 290 insertions, 0 deletions
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
new file mode 100644
index 0000000000..254e7e4b80
--- /dev/null
+++ b/tests/handlers/test_room_member.py
@@ -0,0 +1,290 @@
+from http import HTTPStatus
+from unittest.mock import Mock, patch
+
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+import synapse.rest.client.login
+import synapse.rest.client.room
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import LimitExceededError
+from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.events import FrozenEventV3
+from synapse.federation.federation_client import SendJoinResult
+from synapse.server import HomeServer
+from synapse.types import UserID, create_requester
+from synapse.util import Clock
+
+from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.server import make_request
+from tests.test_utils import make_awaitable
+from tests.unittest import FederatingHomeserverTestCase, override_config
+
+
+class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        synapse.rest.client.login.register_servlets,
+        synapse.rest.client.room.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.handler = hs.get_room_member_handler()
+
+        # Create three users.
+        self.alice = self.register_user("alice", "pass")
+        self.alice_token = self.login("alice", "pass")
+        self.bob = self.register_user("bob", "pass")
+        self.bob_token = self.login("bob", "pass")
+        self.chris = self.register_user("chris", "pass")
+        self.chris_token = self.login("chris", "pass")
+
+        # Create a room on this homeserver. Note that this counts as a join: it
+        # contributes to the rate limter's count of actions
+        self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+        self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}"
+
+    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None:
+        # The rate limiter has accumulated one token from Alice's join after the create
+        # event.
+        # Try joining the room as Bob.
+        self.get_success(
+            self.handler.update_membership(
+                requester=create_requester(self.bob),
+                target=UserID.from_string(self.bob),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+            )
+        )
+
+        # The rate limiter bucket is full. A second join should be denied.
+        self.get_failure(
+            self.handler.update_membership(
+                requester=create_requester(self.chris),
+                target=UserID.from_string(self.chris),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+            ),
+            LimitExceededError,
+        )
+
+    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    def test_local_user_profile_edits_dont_contribute_to_limit(self) -> None:
+        # The rate limiter has accumulated one token from Alice's join after the create
+        # event. Alice should still be able to change her displayname.
+        self.get_success(
+            self.handler.update_membership(
+                requester=create_requester(self.alice),
+                target=UserID.from_string(self.alice),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+                content={"displayname": "Alice Cooper"},
+            )
+        )
+
+        # Still room in the limiter bucket. Chris's join should be accepted.
+        self.get_success(
+            self.handler.update_membership(
+                requester=create_requester(self.chris),
+                target=UserID.from_string(self.chris),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+            )
+        )
+
+    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}})
+    def test_remote_joins_contribute_to_rate_limit(self) -> None:
+        # Join once, to fill the rate limiter bucket.
+        #
+        # To do this we have to mock the responses from the remote homeserver.
+        # We also patch out a bunch of event checks on our end. All we're really
+        # trying to check here is that remote joins will bump the rate limter when
+        # they are persisted.
+        create_event_source = {
+            "auth_events": [],
+            "content": {
+                "creator": f"@creator:{self.OTHER_SERVER_NAME}",
+                "room_version": self.hs.config.server.default_room_version.identifier,
+            },
+            "depth": 0,
+            "origin_server_ts": 0,
+            "prev_events": [],
+            "room_id": self.intially_unjoined_room_id,
+            "sender": f"@creator:{self.OTHER_SERVER_NAME}",
+            "state_key": "",
+            "type": EventTypes.Create,
+        }
+        self.add_hashes_and_signatures_from_other_server(
+            create_event_source,
+            self.hs.config.server.default_room_version,
+        )
+        create_event = FrozenEventV3(
+            create_event_source,
+            self.hs.config.server.default_room_version,
+            {},
+            None,
+        )
+
+        join_event_source = {
+            "auth_events": [create_event.event_id],
+            "content": {"membership": "join"},
+            "depth": 1,
+            "origin_server_ts": 100,
+            "prev_events": [create_event.event_id],
+            "sender": self.bob,
+            "state_key": self.bob,
+            "room_id": self.intially_unjoined_room_id,
+            "type": EventTypes.Member,
+        }
+        add_hashes_and_signatures(
+            self.hs.config.server.default_room_version,
+            join_event_source,
+            self.hs.hostname,
+            self.hs.signing_key,
+        )
+        join_event = FrozenEventV3(
+            join_event_source,
+            self.hs.config.server.default_room_version,
+            {},
+            None,
+        )
+
+        mock_make_membership_event = Mock(
+            return_value=make_awaitable(
+                (
+                    self.OTHER_SERVER_NAME,
+                    join_event,
+                    self.hs.config.server.default_room_version,
+                )
+            )
+        )
+        mock_send_join = Mock(
+            return_value=make_awaitable(
+                SendJoinResult(
+                    join_event,
+                    self.OTHER_SERVER_NAME,
+                    state=[create_event],
+                    auth_chain=[create_event],
+                    partial_state=False,
+                    servers_in_room=[],
+                )
+            )
+        )
+
+        with patch.object(
+            self.handler.federation_handler.federation_client,
+            "make_membership_event",
+            mock_make_membership_event,
+        ), patch.object(
+            self.handler.federation_handler.federation_client,
+            "send_join",
+            mock_send_join,
+        ), patch(
+            "synapse.event_auth._is_membership_change_allowed",
+            return_value=None,
+        ), patch(
+            "synapse.handlers.federation_event.check_state_dependent_auth_rules",
+            return_value=None,
+        ):
+            self.get_success(
+                self.handler.update_membership(
+                    requester=create_requester(self.bob),
+                    target=UserID.from_string(self.bob),
+                    room_id=self.intially_unjoined_room_id,
+                    action=Membership.JOIN,
+                    remote_room_hosts=[self.OTHER_SERVER_NAME],
+                )
+            )
+
+            # Try to join as Chris. Should get denied.
+            self.get_failure(
+                self.handler.update_membership(
+                    requester=create_requester(self.chris),
+                    target=UserID.from_string(self.chris),
+                    room_id=self.intially_unjoined_room_id,
+                    action=Membership.JOIN,
+                    remote_room_hosts=[self.OTHER_SERVER_NAME],
+                ),
+                LimitExceededError,
+            )
+
+    # TODO: test that remote joins to a room are rate limited.
+    #   Could do this by setting the burst count to 1, then:
+    #   - remote-joining a room
+    #   - immediately leaving
+    #   - trying to remote-join again.
+
+
+class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        synapse.rest.client.login.register_servlets,
+        synapse.rest.client.room.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.handler = hs.get_room_member_handler()
+
+        # Create three users.
+        self.alice = self.register_user("alice", "pass")
+        self.alice_token = self.login("alice", "pass")
+        self.bob = self.register_user("bob", "pass")
+        self.bob_token = self.login("bob", "pass")
+        self.chris = self.register_user("chris", "pass")
+        self.chris_token = self.login("chris", "pass")
+
+        # Create a room on this homeserver.
+        # Note that this counts as a
+        self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+        self.intially_unjoined_room_id = "!example:otherhs"
+
+    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    def test_local_users_joining_on_another_worker_contribute_to_rate_limit(
+        self,
+    ) -> None:
+        # The rate limiter has accumulated one token from Alice's join after the create
+        # event.
+        self.replicate()
+
+        # Spawn another worker and have bob join via it.
+        worker_app = self.make_worker_hs(
+            "synapse.app.generic_worker", extra_config={"worker_name": "other worker"}
+        )
+        worker_site = self._hs_to_site[worker_app]
+        channel = make_request(
+            self.reactor,
+            worker_site,
+            "POST",
+            f"/_matrix/client/v3/rooms/{self.room_id}/join",
+            access_token=self.bob_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+        # wait for join to arrive over replication
+        self.replicate()
+
+        # Try to join as Chris on the worker. Should get denied because Alice
+        # and Bob have both joined the room.
+        self.get_failure(
+            worker_app.get_room_member_handler().update_membership(
+                requester=create_requester(self.chris),
+                target=UserID.from_string(self.chris),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+            ),
+            LimitExceededError,
+        )
+
+        # Try to join as Chris on the original worker. Should get denied because Alice
+        # and Bob have both joined the room.
+        self.get_failure(
+            self.handler.update_membership(
+                requester=create_requester(self.chris),
+                target=UserID.from_string(self.chris),
+                room_id=self.room_id,
+                action=Membership.JOIN,
+            ),
+            LimitExceededError,
+        )