diff --git a/tests/events/test_auto_accept_invites.py b/tests/events/test_auto_accept_invites.py
new file mode 100644
index 0000000000..7fb4d4fa90
--- /dev/null
+++ b/tests/events/test_auto_accept_invites.py
@@ -0,0 +1,657 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2021 The Matrix.org Foundation C.I.C
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+import asyncio
+from asyncio import Future
+from http import HTTPStatus
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, TypeVar, cast
+from unittest.mock import Mock
+
+import attr
+from parameterized import parameterized
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
+from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
+from synapse.events.auto_accept_invites import InviteAutoAccepter
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.handlers.sync import JoinedSyncResult, SyncRequestKey, SyncVersion
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import StreamToken, create_requester
+from synapse.util import Clock
+
+from tests.handlers.test_sync import generate_sync_config
+from tests.unittest import (
+ FederatingHomeserverTestCase,
+ HomeserverTestCase,
+ TestCase,
+ override_config,
+)
+
+
+class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
+ """
+ Integration test cases for auto-accepting invites.
+ """
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ hs = self.setup_test_homeserver()
+ self.handler = hs.get_federation_handler()
+ self.store = hs.get_datastores().main
+ return hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.sync_handler = self.hs.get_sync_handler()
+ self.module_api = hs.get_module_api()
+
+ @parameterized.expand(
+ [
+ [False],
+ [True],
+ ]
+ )
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ },
+ }
+ )
+ def test_auto_accept_invites(self, direct_room: bool) -> None:
+ """Test that a user automatically joins a room when invited, if the
+ module is enabled.
+ """
+ # A local user who sends an invite
+ inviting_user_id = self.register_user("inviter", "pass")
+ inviting_user_tok = self.login("inviter", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+ self.login("invitee", "pass")
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviting_user_id,
+ is_public=False,
+ tok=inviting_user_tok,
+ )
+
+ self.helper.invite(
+ room_id,
+ inviting_user_id,
+ invited_user_id,
+ tok=inviting_user_tok,
+ extra_data={"is_direct": direct_room},
+ )
+
+ # Check that the invite receiving user has automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 1)
+
+ join_update: JoinedSyncResult = join_updates[0]
+ self.assertEqual(join_update.room_id, room_id)
+
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": False,
+ },
+ }
+ )
+ def test_module_not_enabled(self) -> None:
+ """Test that a user does not automatically join a room when invited,
+ if the module is not enabled.
+ """
+ # A local user who sends an invite
+ inviting_user_id = self.register_user("inviter", "pass")
+ inviting_user_tok = self.login("inviter", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+ self.login("invitee", "pass")
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviting_user_id, is_public=False, tok=inviting_user_tok
+ )
+
+ self.helper.invite(
+ room_id,
+ inviting_user_id,
+ invited_user_id,
+ tok=inviting_user_tok,
+ )
+
+ # Check that the invite receiving user has not automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 0)
+
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ },
+ }
+ )
+ def test_invite_from_remote_user(self) -> None:
+ """Test that an invite from a remote user results in the invited user
+ automatically joining the room.
+ """
+ # A remote user who sends the invite
+ remote_server = "otherserver"
+ remote_user = "@otheruser:" + remote_server
+
+ # A local user who creates the room
+ creator_user_id = self.register_user("creator", "pass")
+ creator_user_tok = self.login("creator", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+ self.login("invitee", "pass")
+
+ room_id = self.helper.create_room_as(
+ room_creator=creator_user_id, tok=creator_user_tok
+ )
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ invite_event = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": remote_user,
+ "state_key": invited_user_id,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+ self.get_success(
+ self.handler.on_invite_request(
+ remote_server,
+ invite_event,
+ invite_event.room_version,
+ )
+ )
+
+ # Check that the invite receiving user has automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 1)
+
+ join_update: JoinedSyncResult = join_updates[0]
+ self.assertEqual(join_update.room_id, room_id)
+
+ @parameterized.expand(
+ [
+ [False, False],
+ [True, True],
+ ]
+ )
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ "only_for_direct_messages": True,
+ },
+ }
+ )
+ def test_accept_invite_direct_message(
+ self,
+ direct_room: bool,
+ expect_auto_join: bool,
+ ) -> None:
+ """Tests that, if the module is configured to only accept DM invites, invites to DM rooms are still
+ automatically accepted. Otherwise they are rejected.
+ """
+ # A local user who sends an invite
+ inviting_user_id = self.register_user("inviter", "pass")
+ inviting_user_tok = self.login("inviter", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+ self.login("invitee", "pass")
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviting_user_id,
+ is_public=False,
+ tok=inviting_user_tok,
+ )
+
+ self.helper.invite(
+ room_id,
+ inviting_user_id,
+ invited_user_id,
+ tok=inviting_user_tok,
+ extra_data={"is_direct": direct_room},
+ )
+
+ if expect_auto_join:
+ # Check that the invite receiving user has automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 1)
+
+ join_update: JoinedSyncResult = join_updates[0]
+ self.assertEqual(join_update.room_id, room_id)
+ else:
+ # Check that the invite receiving user has not automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 0)
+
+ @parameterized.expand(
+ [
+ [False, True],
+ [True, False],
+ ]
+ )
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ "only_from_local_users": True,
+ },
+ }
+ )
+ def test_accept_invite_local_user(
+ self, remote_inviter: bool, expect_auto_join: bool
+ ) -> None:
+ """Tests that, if the module is configured to only accept invites from local users, invites
+ from local users are still automatically accepted. Otherwise they are rejected.
+ """
+ # A local user who sends an invite
+ creator_user_id = self.register_user("inviter", "pass")
+ creator_user_tok = self.login("inviter", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+ self.login("invitee", "pass")
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ creator_user_id, is_public=False, tok=creator_user_tok
+ )
+
+ if remote_inviter:
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ # A remote user who sends the invite
+ remote_server = "otherserver"
+ remote_user = "@otheruser:" + remote_server
+
+ invite_event = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": remote_user,
+ "state_key": invited_user_id,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+ self.get_success(
+ self.handler.on_invite_request(
+ remote_server,
+ invite_event,
+ invite_event.room_version,
+ )
+ )
+ else:
+ self.helper.invite(
+ room_id,
+ creator_user_id,
+ invited_user_id,
+ tok=creator_user_tok,
+ )
+
+ if expect_auto_join:
+ # Check that the invite receiving user has automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 1)
+
+ join_update: JoinedSyncResult = join_updates[0]
+ self.assertEqual(join_update.room_id, room_id)
+ else:
+ # Check that the invite receiving user has not automatically joined the room when syncing
+ join_updates, _ = sync_join(self, invited_user_id)
+ self.assertEqual(len(join_updates), 0)
+
+
+_request_key = 0
+
+
+def generate_request_key() -> SyncRequestKey:
+ global _request_key
+ _request_key += 1
+ return ("request_key", _request_key)
+
+
+def sync_join(
+ testcase: HomeserverTestCase,
+ user_id: str,
+ since_token: Optional[StreamToken] = None,
+) -> Tuple[List[JoinedSyncResult], StreamToken]:
+ """Perform a sync request for the given user and return the user join updates
+ they've received, as well as the next_batch token.
+
+ This method assumes testcase.sync_handler points to the homeserver's sync handler.
+
+ Args:
+ testcase: The testcase that is currently being run.
+ user_id: The ID of the user to generate a sync response for.
+ since_token: An optional token to indicate from at what point to sync from.
+
+ Returns:
+ A tuple containing a list of join updates, and the sync response's
+ next_batch token.
+ """
+ requester = create_requester(user_id)
+ sync_config = generate_sync_config(requester.user.to_string())
+ sync_result = testcase.get_success(
+ testcase.hs.get_sync_handler().wait_for_sync_for_user(
+ requester,
+ sync_config,
+ SyncVersion.SYNC_V2,
+ generate_request_key(),
+ since_token,
+ )
+ )
+
+ return sync_result.joined, sync_result.next_batch
+
+
+class InviteAutoAccepterInternalTestCase(TestCase):
+ """
+ Test cases which exercise the internals of the InviteAutoAccepter.
+ """
+
+ def setUp(self) -> None:
+ self.module = create_module()
+ self.user_id = "@peter:test"
+ self.invitee = "@lesley:test"
+ self.remote_invitee = "@thomas:remote"
+
+ # We know our module API is a mock, but mypy doesn't.
+ self.mocked_update_membership: Mock = self.module._api.update_room_membership # type: ignore[assignment]
+
+ async def test_accept_invite_with_failures(self) -> None:
+ """Tests that receiving an invite for a local user makes the module attempt to
+ make the invitee join the room. This test verifies that it works if the call to
+ update membership returns exceptions before successfully completing and returning an event.
+ """
+ invite = MockEvent(
+ sender="@inviter:test",
+ state_key="@invitee:test",
+ type="m.room.member",
+ content={"membership": "invite"},
+ )
+
+ join_event = MockEvent(
+ sender="someone",
+ state_key="someone",
+ type="m.room.member",
+ content={"membership": "join"},
+ )
+ # the first two calls raise an exception while the third call is successful
+ self.mocked_update_membership.side_effect = [
+ SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
+ SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
+ make_awaitable(join_event),
+ ]
+
+ # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+ # EventBase.
+ await self.module.on_new_event(event=invite) # type: ignore[arg-type]
+
+ await self.retry_assertions(
+ self.mocked_update_membership,
+ 3,
+ sender=invite.state_key,
+ target=invite.state_key,
+ room_id=invite.room_id,
+ new_membership="join",
+ )
+
+ async def test_accept_invite_failures(self) -> None:
+ """Tests that receiving an invite for a local user makes the module attempt to
+ make the invitee join the room. This test verifies that if the update_membership call
+ fails consistently, _retry_make_join will break the loop after the set number of retries and
+ execution will continue.
+ """
+ invite = MockEvent(
+ sender=self.user_id,
+ state_key=self.invitee,
+ type="m.room.member",
+ content={"membership": "invite"},
+ )
+ self.mocked_update_membership.side_effect = SynapseError(
+ HTTPStatus.FORBIDDEN, "Forbidden"
+ )
+
+ # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+ # EventBase.
+ await self.module.on_new_event(event=invite) # type: ignore[arg-type]
+
+ await self.retry_assertions(
+ self.mocked_update_membership,
+ 5,
+ sender=invite.state_key,
+ target=invite.state_key,
+ room_id=invite.room_id,
+ new_membership="join",
+ )
+
+ async def test_not_state(self) -> None:
+ """Tests that receiving an invite that's not a state event does nothing."""
+ invite = MockEvent(
+ sender=self.user_id, type="m.room.member", content={"membership": "invite"}
+ )
+
+ # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+ # EventBase.
+ await self.module.on_new_event(event=invite) # type: ignore[arg-type]
+
+ self.mocked_update_membership.assert_not_called()
+
+ async def test_not_invite(self) -> None:
+ """Tests that receiving a membership update that's not an invite does nothing."""
+ invite = MockEvent(
+ sender=self.user_id,
+ state_key=self.user_id,
+ type="m.room.member",
+ content={"membership": "join"},
+ )
+
+ # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+ # EventBase.
+ await self.module.on_new_event(event=invite) # type: ignore[arg-type]
+
+ self.mocked_update_membership.assert_not_called()
+
+ async def test_not_membership(self) -> None:
+ """Tests that receiving a state event that's not a membership update does
+ nothing.
+ """
+ invite = MockEvent(
+ sender=self.user_id,
+ state_key=self.user_id,
+ type="org.matrix.test",
+ content={"foo": "bar"},
+ )
+
+ # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+ # EventBase.
+ await self.module.on_new_event(event=invite) # type: ignore[arg-type]
+
+ self.mocked_update_membership.assert_not_called()
+
+ def test_config_parse(self) -> None:
+ """Tests that a correct configuration parses."""
+ config = {
+ "auto_accept_invites": {
+ "enabled": True,
+ "only_for_direct_messages": True,
+ "only_from_local_users": True,
+ }
+ }
+ parsed_config = AutoAcceptInvitesConfig()
+ parsed_config.read_config(config)
+
+ self.assertTrue(parsed_config.enabled)
+ self.assertTrue(parsed_config.accept_invites_only_for_direct_messages)
+ self.assertTrue(parsed_config.accept_invites_only_from_local_users)
+
+ def test_runs_on_only_one_worker(self) -> None:
+ """
+ Tests that the module only runs on the specified worker.
+ """
+ # By default, we run on the main process...
+ main_module = create_module(
+ config_override={"auto_accept_invites": {"enabled": True}}, worker_name=None
+ )
+ cast(
+ Mock, main_module._api.register_third_party_rules_callbacks
+ ).assert_called_once()
+
+ # ...and not on other workers (like synchrotrons)...
+ sync_module = create_module(worker_name="synchrotron42")
+ cast(
+ Mock, sync_module._api.register_third_party_rules_callbacks
+ ).assert_not_called()
+
+ # ...unless we configured them to be the designated worker.
+ specified_module = create_module(
+ config_override={
+ "auto_accept_invites": {
+ "enabled": True,
+ "worker_to_run_on": "account_data1",
+ }
+ },
+ worker_name="account_data1",
+ )
+ cast(
+ Mock, specified_module._api.register_third_party_rules_callbacks
+ ).assert_called_once()
+
+ async def retry_assertions(
+ self, mock: Mock, call_count: int, **kwargs: Any
+ ) -> None:
+ """
+ This is a hacky way to ensure that the assertions are not called before the other coroutine
+ has a chance to call `update_room_membership`. It catches the exception caused by a failure,
+ and sleeps the thread before retrying, up until 5 tries.
+
+ Args:
+ call_count: the number of times the mock should have been called
+ mock: the mocked function we want to assert on
+ kwargs: keyword arguments to assert that the mock was called with
+ """
+
+ i = 0
+ while i < 5:
+ try:
+ # Check that the mocked method is called the expected amount of times and with the right
+ # arguments to attempt to make the user join the room.
+ mock.assert_called_with(**kwargs)
+ self.assertEqual(call_count, mock.call_count)
+ break
+ except AssertionError as e:
+ i += 1
+ if i == 5:
+ # we've used up the tries, force the test to fail as we've already caught the exception
+ self.fail(e)
+ await asyncio.sleep(1)
+
+
+@attr.s(auto_attribs=True)
+class MockEvent:
+ """Mocks an event. Only exposes properties the module uses."""
+
+ sender: str
+ type: str
+ content: Dict[str, Any]
+ room_id: str = "!someroom"
+ state_key: Optional[str] = None
+
+ def is_state(self) -> bool:
+ """Checks if the event is a state event by checking if it has a state key."""
+ return self.state_key is not None
+
+ @property
+ def membership(self) -> str:
+ """Extracts the membership from the event. Should only be called on an event
+ that's a membership event, and will raise a KeyError otherwise.
+ """
+ membership: str = self.content["membership"]
+ return membership
+
+
+T = TypeVar("T")
+TV = TypeVar("TV")
+
+
+async def make_awaitable(value: T) -> T:
+ return value
+
+
+def make_multiple_awaitable(result: TV) -> Awaitable[TV]:
+ """
+ Makes an awaitable, suitable for mocking an `async` function.
+ This uses Futures as they can be awaited multiple times so can be returned
+ to multiple callers.
+ """
+ future: Future[TV] = Future()
+ future.set_result(result)
+ return future
+
+
+def create_module(
+ config_override: Optional[Dict[str, Any]] = None, worker_name: Optional[str] = None
+) -> InviteAutoAccepter:
+ # Create a mock based on the ModuleApi spec, but override some mocked functions
+ # because some capabilities are needed for running the tests.
+ module_api = Mock(spec=ModuleApi)
+ module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test"
+ module_api.worker_name = worker_name
+ module_api.sleep.return_value = make_multiple_awaitable(None)
+
+ if config_override is None:
+ config_override = {}
+
+ config = AutoAcceptInvitesConfig()
+ config.read_config(config_override)
+
+ return InviteAutoAccepter(config, module_api)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index fe00afe198..7362bde7ab 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -170,6 +170,7 @@ class RestHelper:
targ: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
+ extra_data: Optional[dict] = None,
) -> JsonDict:
return self.change_membership(
room=room,
@@ -178,6 +179,7 @@ class RestHelper:
tok=tok,
membership=Membership.INVITE,
expect_code=expect_code,
+ extra_data=extra_data,
)
def join(
diff --git a/tests/server.py b/tests/server.py
index 434be3d22c..f3a917f835 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -85,6 +85,7 @@ from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
+from synapse.events.auto_accept_invites import InviteAutoAccepter
from synapse.events.presence_router import load_legacy_presence_router
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseRequest
@@ -1156,6 +1157,11 @@ def setup_test_homeserver(
for module, module_config in hs.config.modules.loaded_modules:
module(config=module_config, api=module_api)
+ if hs.config.auto_accept_invites.enabled:
+ # Start the local auto_accept_invites module.
+ m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api)
+ logger.info("Loaded local module %s", m)
+
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs)
|