diff options
-rw-r--r-- | tests/handlers/test_room_member.py | 135 |
1 files changed, 113 insertions, 22 deletions
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index 5c927f0c23..254e7e4b80 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -1,13 +1,16 @@ from http import HTTPStatus -from unittest.mock import patch +from unittest.mock import Mock, patch from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.rest.client.login import synapse.rest.client.room -from synapse.api.constants import Membership +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import LimitExceededError +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events import FrozenEventV3 +from synapse.federation.federation_client import SendJoinResult from synapse.server import HomeServer from synapse.types import UserID, create_requester from synapse.util import Clock @@ -15,10 +18,10 @@ from synapse.util import Clock from tests.replication._base import RedisMultiWorkerStreamTestCase from tests.server import make_request from tests.test_utils import make_awaitable -from tests.unittest import HomeserverTestCase, override_config +from tests.unittest import FederatingHomeserverTestCase, override_config -class TestJoinsLimitedByPerRoomRateLimiter(HomeserverTestCase): +class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, synapse.rest.client.login.register_servlets, @@ -36,10 +39,11 @@ class TestJoinsLimitedByPerRoomRateLimiter(HomeserverTestCase): self.chris = self.register_user("chris", "pass") self.chris_token = self.login("chris", "pass") - # Create a room on this homeserver. - # Note that this counts as a + # Create a room on this homeserver. Note that this counts as a join: it + # contributes to the rate limter's count of actions self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token) - self.intially_unjoined_room_id = "!example:otherhs" + + self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}" @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None: @@ -92,12 +96,97 @@ class TestJoinsLimitedByPerRoomRateLimiter(HomeserverTestCase): @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}}) def test_remote_joins_contribute_to_rate_limit(self) -> None: - # Join once, to fill the rate limiter bucket. Patch out the `_remote_join" call - # because there is no other homeserver for us to join via. + # Join once, to fill the rate limiter bucket. + # + # To do this we have to mock the responses from the remote homeserver. + # We also patch out a bunch of event checks on our end. All we're really + # trying to check here is that remote joins will bump the rate limter when + # they are persisted. + create_event_source = { + "auth_events": [], + "content": { + "creator": f"@creator:{self.OTHER_SERVER_NAME}", + "room_version": self.hs.config.server.default_room_version.identifier, + }, + "depth": 0, + "origin_server_ts": 0, + "prev_events": [], + "room_id": self.intially_unjoined_room_id, + "sender": f"@creator:{self.OTHER_SERVER_NAME}", + "state_key": "", + "type": EventTypes.Create, + } + self.add_hashes_and_signatures_from_other_server( + create_event_source, + self.hs.config.server.default_room_version, + ) + create_event = FrozenEventV3( + create_event_source, + self.hs.config.server.default_room_version, + {}, + None, + ) + + join_event_source = { + "auth_events": [create_event.event_id], + "content": {"membership": "join"}, + "depth": 1, + "origin_server_ts": 100, + "prev_events": [create_event.event_id], + "sender": self.bob, + "state_key": self.bob, + "room_id": self.intially_unjoined_room_id, + "type": EventTypes.Member, + } + add_hashes_and_signatures( + self.hs.config.server.default_room_version, + join_event_source, + self.hs.hostname, + self.hs.signing_key, + ) + join_event = FrozenEventV3( + join_event_source, + self.hs.config.server.default_room_version, + {}, + None, + ) + + mock_make_membership_event = Mock( + return_value=make_awaitable( + ( + self.OTHER_SERVER_NAME, + join_event, + self.hs.config.server.default_room_version, + ) + ) + ) + mock_send_join = Mock( + return_value=make_awaitable( + SendJoinResult( + join_event, + self.OTHER_SERVER_NAME, + state=[create_event], + auth_chain=[create_event], + partial_state=False, + servers_in_room=[], + ) + ) + ) + with patch.object( - self.handler, - "_remote_join", - return_value=make_awaitable(("$dummy_event", 1000)), + self.handler.federation_handler.federation_client, + "make_membership_event", + mock_make_membership_event, + ), patch.object( + self.handler.federation_handler.federation_client, + "send_join", + mock_send_join, + ), patch( + "synapse.event_auth._is_membership_change_allowed", + return_value=None, + ), patch( + "synapse.handlers.federation_event.check_state_dependent_auth_rules", + return_value=None, ): self.get_success( self.handler.update_membership( @@ -105,19 +194,21 @@ class TestJoinsLimitedByPerRoomRateLimiter(HomeserverTestCase): target=UserID.from_string(self.bob), room_id=self.intially_unjoined_room_id, action=Membership.JOIN, + remote_room_hosts=[self.OTHER_SERVER_NAME], ) ) - # Try to join as Chris. Should get denied. - self.get_failure( - self.handler.update_membership( - requester=create_requester(self.chris), - target=UserID.from_string(self.chris), - room_id=self.intially_unjoined_room_id, - action=Membership.JOIN, - ), - LimitExceededError, - ) + # Try to join as Chris. Should get denied. + self.get_failure( + self.handler.update_membership( + requester=create_requester(self.chris), + target=UserID.from_string(self.chris), + room_id=self.intially_unjoined_room_id, + action=Membership.JOIN, + remote_room_hosts=[self.OTHER_SERVER_NAME], + ), + LimitExceededError, + ) # TODO: test that remote joins to a room are rate limited. # Could do this by setting the burst count to 1, then: |