diff options
author | devonh <devon.dmytro@gmail.com> | 2024-05-21 20:09:17 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-21 20:09:17 +0000 |
commit | 6a9a641fb86b04587840bcb6b76af9a0acef9b54 (patch) | |
tree | a8c0bace62f61de7b605c23635d85d111abf11d9 /tests | |
parent | Improve perf of sync device lists (#17216) (diff) | |
download | synapse-6a9a641fb86b04587840bcb6b76af9a0acef9b54.tar.xz |
Bring auto-accept invite logic into Synapse (#17147)
This PR ports the logic from the [synapse_auto_accept_invite](https://github.com/matrix-org/synapse-auto-accept-invite) module into synapse. I went with the naive approach of injecting the "module" next to where third party modules are currently loaded. If there is a better/preferred way to handle this, I'm all ears. It wasn't obvious to me if there was a better location to add this logic that would cleanly apply to all incoming invite events. Relies on https://github.com/element-hq/synapse/pull/17166 to fix linter errors.
Diffstat (limited to 'tests')
-rw-r--r-- | tests/events/test_auto_accept_invites.py | 657 | ||||
-rw-r--r-- | tests/rest/client/utils.py | 2 | ||||
-rw-r--r-- | tests/server.py | 6 |
3 files changed, 665 insertions, 0 deletions
diff --git a/tests/events/test_auto_accept_invites.py b/tests/events/test_auto_accept_invites.py new file mode 100644 index 0000000000..7fb4d4fa90 --- /dev/null +++ b/tests/events/test_auto_accept_invites.py @@ -0,0 +1,657 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright 2021 The Matrix.org Foundation C.I.C +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# +# Originally licensed under the Apache License, Version 2.0: +# <http://www.apache.org/licenses/LICENSE-2.0>. +# +# [This file includes modifications made by New Vector Limited] +# +# +import asyncio +from asyncio import Future +from http import HTTPStatus +from typing import Any, Awaitable, Dict, List, Optional, Tuple, TypeVar, cast +from unittest.mock import Mock + +import attr +from parameterized import parameterized + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError +from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig +from synapse.events.auto_accept_invites import InviteAutoAccepter +from synapse.federation.federation_base import event_from_pdu_json +from synapse.handlers.sync import JoinedSyncResult, SyncRequestKey, SyncVersion +from synapse.module_api import ModuleApi +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import StreamToken, create_requester +from synapse.util import Clock + +from tests.handlers.test_sync import generate_sync_config +from tests.unittest import ( + FederatingHomeserverTestCase, + HomeserverTestCase, + TestCase, + override_config, +) + + +class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase): + """ + Integration test cases for auto-accepting invites. + """ + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + hs = self.setup_test_homeserver() + self.handler = hs.get_federation_handler() + self.store = hs.get_datastores().main + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.sync_handler = self.hs.get_sync_handler() + self.module_api = hs.get_module_api() + + @parameterized.expand( + [ + [False], + [True], + ] + ) + @override_config( + { + "auto_accept_invites": { + "enabled": True, + }, + } + ) + def test_auto_accept_invites(self, direct_room: bool) -> None: + """Test that a user automatically joins a room when invited, if the + module is enabled. + """ + # A local user who sends an invite + inviting_user_id = self.register_user("inviter", "pass") + inviting_user_tok = self.login("inviter", "pass") + + # A local user who receives an invite + invited_user_id = self.register_user("invitee", "pass") + self.login("invitee", "pass") + + # Create a room and send an invite to the other user + room_id = self.helper.create_room_as( + inviting_user_id, + is_public=False, + tok=inviting_user_tok, + ) + + self.helper.invite( + room_id, + inviting_user_id, + invited_user_id, + tok=inviting_user_tok, + extra_data={"is_direct": direct_room}, + ) + + # Check that the invite receiving user has automatically joined the room when syncing + join_updates, _ = sync_join(self, invited_user_id) + self.assertEqual(len(join_updates), 1) + + join_update: JoinedSyncResult = join_updates[0] + self.assertEqual(join_update.room_id, room_id) + + @override_config( + { + "auto_accept_invites": { + "enabled": False, + }, + } + ) + def test_module_not_enabled(self) -> None: + """Test that a user does not automatically join a room when invited, + if the module is not enabled. + """ + # A local user who sends an invite + inviting_user_id = self.register_user("inviter", "pass") + inviting_user_tok = self.login("inviter", "pass") + + # A local user who receives an invite + invited_user_id = self.register_user("invitee", "pass") + self.login("invitee", "pass") + + # Create a room and send an invite to the other user + room_id = self.helper.create_room_as( + inviting_user_id, is_public=False, tok=inviting_user_tok + ) + + self.helper.invite( + room_id, + inviting_user_id, + invited_user_id, + tok=inviting_user_tok, + ) + + # Check that the invite receiving user has not automatically joined the room when syncing + join_updates, _ = sync_join(self, invited_user_id) + self.assertEqual(len(join_updates), 0) + + @override_config( + { + "auto_accept_invites": { + "enabled": True, + }, + } + ) + def test_invite_from_remote_user(self) -> None: + """Test that an invite from a remote user results in the invited user + automatically joining the room. + """ + # A remote user who sends the invite + remote_server = "otherserver" + remote_user = "@otheruser:" + remote_server + + # A local user who creates the room + creator_user_id = self.register_user("creator", "pass") + creator_user_tok = self.login("creator", "pass") + + # A local user who receives an invite + invited_user_id = self.register_user("invitee", "pass") + self.login("invitee", "pass") + + room_id = self.helper.create_room_as( + room_creator=creator_user_id, tok=creator_user_tok + ) + room_version = self.get_success(self.store.get_room_version(room_id)) + + invite_event = event_from_pdu_json( + { + "type": EventTypes.Member, + "content": {"membership": "invite"}, + "room_id": room_id, + "sender": remote_user, + "state_key": invited_user_id, + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + self.get_success( + self.handler.on_invite_request( + remote_server, + invite_event, + invite_event.room_version, + ) + ) + + # Check that the invite receiving user has automatically joined the room when syncing + join_updates, _ = sync_join(self, invited_user_id) + self.assertEqual(len(join_updates), 1) + + join_update: JoinedSyncResult = join_updates[0] + self.assertEqual(join_update.room_id, room_id) + + @parameterized.expand( + [ + [False, False], + [True, True], + ] + ) + @override_config( + { + "auto_accept_invites": { + "enabled": True, + "only_for_direct_messages": True, + }, + } + ) + def test_accept_invite_direct_message( + self, + direct_room: bool, + expect_auto_join: bool, + ) -> None: + """Tests that, if the module is configured to only accept DM invites, invites to DM rooms are still + automatically accepted. Otherwise they are rejected. + """ + # A local user who sends an invite + inviting_user_id = self.register_user("inviter", "pass") + inviting_user_tok = self.login("inviter", "pass") + + # A local user who receives an invite + invited_user_id = self.register_user("invitee", "pass") + self.login("invitee", "pass") + + # Create a room and send an invite to the other user + room_id = self.helper.create_room_as( + inviting_user_id, + is_public=False, + tok=inviting_user_tok, + ) + + self.helper.invite( + room_id, + inviting_user_id, + invited_user_id, + tok=inviting_user_tok, + extra_data={"is_direct": direct_room}, + ) + + if expect_auto_join: + # Check that the invite receiving user has automatically joined the room when syncing + join_updates, _ = sync_join(self, invited_user_id) + self.assertEqual(len(join_updates), 1) + + join_update: JoinedSyncResult = join_updates[0] + self.assertEqual(join_update.room_id, room_id) + else: + # Check that the invite receiving user has not automatically joined the room when syncing + join_updates, _ = sync_join(self, invited_user_id) + self.assertEqual(len(join_updates), 0) + + @parameterized.expand( + [ + [False, True], + [True, False], + ] + ) + @override_config( + { + "auto_accept_invites": { + "enabled": True, + "only_from_local_users": True, + }, + } + ) + def test_accept_invite_local_user( + self, remote_inviter: bool, expect_auto_join: bool + ) -> None: + """Tests that, if the module is configured to only accept invites from local users, invites + from local users are still automatically accepted. Otherwise they are rejected. + """ + # A local user who sends an invite + creator_user_id = self.register_user("inviter", "pass") + creator_user_tok = self.login("inviter", "pass") + + # A local user who receives an invite + invited_user_id = self.register_user("invitee", "pass") + self.login("invitee", "pass") + + # Create a room and send an invite to the other user + room_id = self.helper.create_room_as( + creator_user_id, is_public=False, tok=creator_user_tok + ) + + if remote_inviter: + room_version = self.get_success(self.store.get_room_version(room_id)) + + # A remote user who sends the invite + remote_server = "otherserver" + remote_user = "@otheruser:" + remote_server + + invite_event = event_from_pdu_json( + { + "type": EventTypes.Member, + "content": {"membership": "invite"}, + "room_id": room_id, + "sender": remote_user, + "state_key": invited_user_id, + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + self.get_success( + self.handler.on_invite_request( + remote_server, + invite_event, + invite_event.room_version, + ) + ) + else: + self.helper.invite( + room_id, + creator_user_id, + invited_user_id, + tok=creator_user_tok, + ) + + if expect_auto_join: + # Check that the invite receiving user has automatically joined the room when syncing + join_updates, _ = sync_join(self, invited_user_id) + self.assertEqual(len(join_updates), 1) + + join_update: JoinedSyncResult = join_updates[0] + self.assertEqual(join_update.room_id, room_id) + else: + # Check that the invite receiving user has not automatically joined the room when syncing + join_updates, _ = sync_join(self, invited_user_id) + self.assertEqual(len(join_updates), 0) + + +_request_key = 0 + + +def generate_request_key() -> SyncRequestKey: + global _request_key + _request_key += 1 + return ("request_key", _request_key) + + +def sync_join( + testcase: HomeserverTestCase, + user_id: str, + since_token: Optional[StreamToken] = None, +) -> Tuple[List[JoinedSyncResult], StreamToken]: + """Perform a sync request for the given user and return the user join updates + they've received, as well as the next_batch token. + + This method assumes testcase.sync_handler points to the homeserver's sync handler. + + Args: + testcase: The testcase that is currently being run. + user_id: The ID of the user to generate a sync response for. + since_token: An optional token to indicate from at what point to sync from. + + Returns: + A tuple containing a list of join updates, and the sync response's + next_batch token. + """ + requester = create_requester(user_id) + sync_config = generate_sync_config(requester.user.to_string()) + sync_result = testcase.get_success( + testcase.hs.get_sync_handler().wait_for_sync_for_user( + requester, + sync_config, + SyncVersion.SYNC_V2, + generate_request_key(), + since_token, + ) + ) + + return sync_result.joined, sync_result.next_batch + + +class InviteAutoAccepterInternalTestCase(TestCase): + """ + Test cases which exercise the internals of the InviteAutoAccepter. + """ + + def setUp(self) -> None: + self.module = create_module() + self.user_id = "@peter:test" + self.invitee = "@lesley:test" + self.remote_invitee = "@thomas:remote" + + # We know our module API is a mock, but mypy doesn't. + self.mocked_update_membership: Mock = self.module._api.update_room_membership # type: ignore[assignment] + + async def test_accept_invite_with_failures(self) -> None: + """Tests that receiving an invite for a local user makes the module attempt to + make the invitee join the room. This test verifies that it works if the call to + update membership returns exceptions before successfully completing and returning an event. + """ + invite = MockEvent( + sender="@inviter:test", + state_key="@invitee:test", + type="m.room.member", + content={"membership": "invite"}, + ) + + join_event = MockEvent( + sender="someone", + state_key="someone", + type="m.room.member", + content={"membership": "join"}, + ) + # the first two calls raise an exception while the third call is successful + self.mocked_update_membership.side_effect = [ + SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"), + SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"), + make_awaitable(join_event), + ] + + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an + # EventBase. + await self.module.on_new_event(event=invite) # type: ignore[arg-type] + + await self.retry_assertions( + self.mocked_update_membership, + 3, + sender=invite.state_key, + target=invite.state_key, + room_id=invite.room_id, + new_membership="join", + ) + + async def test_accept_invite_failures(self) -> None: + """Tests that receiving an invite for a local user makes the module attempt to + make the invitee join the room. This test verifies that if the update_membership call + fails consistently, _retry_make_join will break the loop after the set number of retries and + execution will continue. + """ + invite = MockEvent( + sender=self.user_id, + state_key=self.invitee, + type="m.room.member", + content={"membership": "invite"}, + ) + self.mocked_update_membership.side_effect = SynapseError( + HTTPStatus.FORBIDDEN, "Forbidden" + ) + + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an + # EventBase. + await self.module.on_new_event(event=invite) # type: ignore[arg-type] + + await self.retry_assertions( + self.mocked_update_membership, + 5, + sender=invite.state_key, + target=invite.state_key, + room_id=invite.room_id, + new_membership="join", + ) + + async def test_not_state(self) -> None: + """Tests that receiving an invite that's not a state event does nothing.""" + invite = MockEvent( + sender=self.user_id, type="m.room.member", content={"membership": "invite"} + ) + + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an + # EventBase. + await self.module.on_new_event(event=invite) # type: ignore[arg-type] + + self.mocked_update_membership.assert_not_called() + + async def test_not_invite(self) -> None: + """Tests that receiving a membership update that's not an invite does nothing.""" + invite = MockEvent( + sender=self.user_id, + state_key=self.user_id, + type="m.room.member", + content={"membership": "join"}, + ) + + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an + # EventBase. + await self.module.on_new_event(event=invite) # type: ignore[arg-type] + + self.mocked_update_membership.assert_not_called() + + async def test_not_membership(self) -> None: + """Tests that receiving a state event that's not a membership update does + nothing. + """ + invite = MockEvent( + sender=self.user_id, + state_key=self.user_id, + type="org.matrix.test", + content={"foo": "bar"}, + ) + + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an + # EventBase. + await self.module.on_new_event(event=invite) # type: ignore[arg-type] + + self.mocked_update_membership.assert_not_called() + + def test_config_parse(self) -> None: + """Tests that a correct configuration parses.""" + config = { + "auto_accept_invites": { + "enabled": True, + "only_for_direct_messages": True, + "only_from_local_users": True, + } + } + parsed_config = AutoAcceptInvitesConfig() + parsed_config.read_config(config) + + self.assertTrue(parsed_config.enabled) + self.assertTrue(parsed_config.accept_invites_only_for_direct_messages) + self.assertTrue(parsed_config.accept_invites_only_from_local_users) + + def test_runs_on_only_one_worker(self) -> None: + """ + Tests that the module only runs on the specified worker. + """ + # By default, we run on the main process... + main_module = create_module( + config_override={"auto_accept_invites": {"enabled": True}}, worker_name=None + ) + cast( + Mock, main_module._api.register_third_party_rules_callbacks + ).assert_called_once() + + # ...and not on other workers (like synchrotrons)... + sync_module = create_module(worker_name="synchrotron42") + cast( + Mock, sync_module._api.register_third_party_rules_callbacks + ).assert_not_called() + + # ...unless we configured them to be the designated worker. + specified_module = create_module( + config_override={ + "auto_accept_invites": { + "enabled": True, + "worker_to_run_on": "account_data1", + } + }, + worker_name="account_data1", + ) + cast( + Mock, specified_module._api.register_third_party_rules_callbacks + ).assert_called_once() + + async def retry_assertions( + self, mock: Mock, call_count: int, **kwargs: Any + ) -> None: + """ + This is a hacky way to ensure that the assertions are not called before the other coroutine + has a chance to call `update_room_membership`. It catches the exception caused by a failure, + and sleeps the thread before retrying, up until 5 tries. + + Args: + call_count: the number of times the mock should have been called + mock: the mocked function we want to assert on + kwargs: keyword arguments to assert that the mock was called with + """ + + i = 0 + while i < 5: + try: + # Check that the mocked method is called the expected amount of times and with the right + # arguments to attempt to make the user join the room. + mock.assert_called_with(**kwargs) + self.assertEqual(call_count, mock.call_count) + break + except AssertionError as e: + i += 1 + if i == 5: + # we've used up the tries, force the test to fail as we've already caught the exception + self.fail(e) + await asyncio.sleep(1) + + +@attr.s(auto_attribs=True) +class MockEvent: + """Mocks an event. Only exposes properties the module uses.""" + + sender: str + type: str + content: Dict[str, Any] + room_id: str = "!someroom" + state_key: Optional[str] = None + + def is_state(self) -> bool: + """Checks if the event is a state event by checking if it has a state key.""" + return self.state_key is not None + + @property + def membership(self) -> str: + """Extracts the membership from the event. Should only be called on an event + that's a membership event, and will raise a KeyError otherwise. + """ + membership: str = self.content["membership"] + return membership + + +T = TypeVar("T") +TV = TypeVar("TV") + + +async def make_awaitable(value: T) -> T: + return value + + +def make_multiple_awaitable(result: TV) -> Awaitable[TV]: + """ + Makes an awaitable, suitable for mocking an `async` function. + This uses Futures as they can be awaited multiple times so can be returned + to multiple callers. + """ + future: Future[TV] = Future() + future.set_result(result) + return future + + +def create_module( + config_override: Optional[Dict[str, Any]] = None, worker_name: Optional[str] = None +) -> InviteAutoAccepter: + # Create a mock based on the ModuleApi spec, but override some mocked functions + # because some capabilities are needed for running the tests. + module_api = Mock(spec=ModuleApi) + module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test" + module_api.worker_name = worker_name + module_api.sleep.return_value = make_multiple_awaitable(None) + + if config_override is None: + config_override = {} + + config = AutoAcceptInvitesConfig() + config.read_config(config_override) + + return InviteAutoAccepter(config, module_api) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index fe00afe198..7362bde7ab 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -170,6 +170,7 @@ class RestHelper: targ: Optional[str] = None, expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, + extra_data: Optional[dict] = None, ) -> JsonDict: return self.change_membership( room=room, @@ -178,6 +179,7 @@ class RestHelper: tok=tok, membership=Membership.INVITE, expect_code=expect_code, + extra_data=extra_data, ) def join( diff --git a/tests/server.py b/tests/server.py index 434be3d22c..f3a917f835 100644 --- a/tests/server.py +++ b/tests/server.py @@ -85,6 +85,7 @@ from twisted.web.server import Request, Site from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig +from synapse.events.auto_accept_invites import InviteAutoAccepter from synapse.events.presence_router import load_legacy_presence_router from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.http.site import SynapseRequest @@ -1156,6 +1157,11 @@ def setup_test_homeserver( for module, module_config in hs.config.modules.loaded_modules: module(config=module_config, api=module_api) + if hs.config.auto_accept_invites.enabled: + # Start the local auto_accept_invites module. + m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api) + logger.info("Loaded local module %s", m) + load_legacy_spam_checkers(hs) load_legacy_third_party_event_rules(hs) load_legacy_presence_router(hs) |