summary refs log tree commit diff
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2024-05-21 20:09:17 +0000
committerGitHub <noreply@github.com>2024-05-21 20:09:17 +0000
commit6a9a641fb86b04587840bcb6b76af9a0acef9b54 (patch)
treea8c0bace62f61de7b605c23635d85d111abf11d9
parentImprove perf of sync device lists (#17216) (diff)
downloadsynapse-6a9a641fb86b04587840bcb6b76af9a0acef9b54.tar.xz
Bring auto-accept invite logic into Synapse (#17147)
This PR ports the logic from the
[synapse_auto_accept_invite](https://github.com/matrix-org/synapse-auto-accept-invite)
module into synapse.

I went with the naive approach of injecting the "module" next to where
third party modules are currently loaded. If there is a better/preferred
way to handle this, I'm all ears. It wasn't obvious to me if there was a
better location to add this logic that would cleanly apply to all
incoming invite events.

Relies on https://github.com/element-hq/synapse/pull/17166 to fix linter
errors.
-rw-r--r--changelog.d/17147.feature1
-rw-r--r--docs/usage/configuration/config_documentation.md29
-rw-r--r--synapse/app/_base.py6
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/auto_accept_invites.py43
-rw-r--r--synapse/config/homeserver.py2
-rw-r--r--synapse/events/auto_accept_invites.py196
-rw-r--r--synapse/handlers/sso.py2
-rw-r--r--tests/events/test_auto_accept_invites.py657
-rw-r--r--tests/rest/client/utils.py2
-rw-r--r--tests/server.py6
11 files changed, 945 insertions, 1 deletions
diff --git a/changelog.d/17147.feature b/changelog.d/17147.feature
new file mode 100644
index 0000000000..7c2cdb6bdf
--- /dev/null
+++ b/changelog.d/17147.feature
@@ -0,0 +1 @@
+Add the ability to auto-accept invites on the behalf of users. See the [`auto_accept_invites`](https://element-hq.github.io/synapse/latest/usage/configuration/config_documentation.html#auto-accept-invites) config option for details.
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index e04fdfdfb0..2c917d1f8e 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -4595,3 +4595,32 @@ background_updates:
     min_batch_size: 10
     default_batch_size: 50
 ```
+---
+## Auto Accept Invites
+Configuration settings related to automatically accepting invites.
+
+---
+### `auto_accept_invites`
+
+Automatically accepting invites controls whether users are presented with an invite request or if they
+are instead automatically joined to a room when receiving an invite. Set the `enabled` sub-option to true to
+enable auto-accepting invites. Defaults to false.
+This setting has the following sub-options:
+* `enabled`: Whether to run the auto-accept invites logic. Defaults to false.
+* `only_for_direct_messages`: Whether invites should be automatically accepted for all room types, or only
+   for direct messages. Defaults to false.
+* `only_from_local_users`: Whether to only automatically accept invites from users on this homeserver. Defaults to false.
+* `worker_to_run_on`: Which worker to run this module on. This must match the "worker_name".
+
+NOTE: Care should be taken not to enable this setting if the `synapse_auto_accept_invite` module is enabled and installed.
+The two modules will compete to perform the same task and may result in undesired behaviour. For example, multiple join
+events could be generated from a single invite.
+
+Example configuration:
+```yaml
+auto_accept_invites:
+    enabled: true
+    only_for_direct_messages: true
+    only_from_local_users: true
+    worker_to_run_on: "worker_1"
+```
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 3182608f73..67e0df1459 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -68,6 +68,7 @@ from synapse.config._base import format_config_error
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig, ManholeConfig, TCPListenerConfig
 from synapse.crypto import context_factory
+from synapse.events.auto_accept_invites import InviteAutoAccepter
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.http.site import SynapseSite
@@ -582,6 +583,11 @@ async def start(hs: "HomeServer") -> None:
         m = module(config, module_api)
         logger.info("Loaded module %s", m)
 
+    if hs.config.auto_accept_invites.enabled:
+        # Start the local auto_accept_invites module.
+        m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api)
+        logger.info("Loaded local module %s", m)
+
     load_legacy_spam_checkers(hs)
     load_legacy_third_party_event_rules(hs)
     load_legacy_presence_router(hs)
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index fc51aed234..d9cb0da38b 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -23,6 +23,7 @@ from synapse.config import (  # noqa: F401
     api,
     appservice,
     auth,
+    auto_accept_invites,
     background_updates,
     cache,
     captcha,
@@ -120,6 +121,7 @@ class RootConfig:
     federation: federation.FederationConfig
     retention: retention.RetentionConfig
     background_updates: background_updates.BackgroundUpdateConfig
+    auto_accept_invites: auto_accept_invites.AutoAcceptInvitesConfig
 
     config_classes: List[Type["Config"]] = ...
     config_files: List[str]
diff --git a/synapse/config/auto_accept_invites.py b/synapse/config/auto_accept_invites.py
new file mode 100644
index 0000000000..d90e13a510
--- /dev/null
+++ b/synapse/config/auto_accept_invites.py
@@ -0,0 +1,43 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+from typing import Any
+
+from synapse.types import JsonDict
+
+from ._base import Config
+
+
+class AutoAcceptInvitesConfig(Config):
+    section = "auto_accept_invites"
+
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
+        auto_accept_invites_config = config.get("auto_accept_invites") or {}
+
+        self.enabled = auto_accept_invites_config.get("enabled", False)
+
+        self.accept_invites_only_for_direct_messages = auto_accept_invites_config.get(
+            "only_for_direct_messages", False
+        )
+
+        self.accept_invites_only_from_local_users = auto_accept_invites_config.get(
+            "only_from_local_users", False
+        )
+
+        self.worker_to_run_on = auto_accept_invites_config.get("worker_to_run_on")
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 72e93ed04f..e36c0bd6ae 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -23,6 +23,7 @@ from .account_validity import AccountValidityConfig
 from .api import ApiConfig
 from .appservice import AppServiceConfig
 from .auth import AuthConfig
+from .auto_accept_invites import AutoAcceptInvitesConfig
 from .background_updates import BackgroundUpdateConfig
 from .cache import CacheConfig
 from .captcha import CaptchaConfig
@@ -105,4 +106,5 @@ class HomeServerConfig(RootConfig):
         RedisConfig,
         ExperimentalConfig,
         BackgroundUpdateConfig,
+        AutoAcceptInvitesConfig,
     ]
diff --git a/synapse/events/auto_accept_invites.py b/synapse/events/auto_accept_invites.py
new file mode 100644
index 0000000000..d88ec51d9d
--- /dev/null
+++ b/synapse/events/auto_accept_invites.py
@@ -0,0 +1,196 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2021 The Matrix.org Foundation C.I.C
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+import logging
+from http import HTTPStatus
+from typing import Any, Dict, Tuple
+
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.errors import SynapseError
+from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
+from synapse.module_api import EventBase, ModuleApi, run_as_background_process
+
+logger = logging.getLogger(__name__)
+
+
+class InviteAutoAccepter:
+    def __init__(self, config: AutoAcceptInvitesConfig, api: ModuleApi):
+        # Keep a reference to the Module API.
+        self._api = api
+        self._config = config
+
+        if not self._config.enabled:
+            return
+
+        should_run_on_this_worker = config.worker_to_run_on == self._api.worker_name
+
+        if not should_run_on_this_worker:
+            logger.info(
+                "Not accepting invites on this worker (configured: %r, here: %r)",
+                config.worker_to_run_on,
+                self._api.worker_name,
+            )
+            return
+
+        logger.info(
+            "Accepting invites on this worker (here: %r)", self._api.worker_name
+        )
+
+        # Register the callback.
+        self._api.register_third_party_rules_callbacks(
+            on_new_event=self.on_new_event,
+        )
+
+    async def on_new_event(self, event: EventBase, *args: Any) -> None:
+        """Listens for new events, and if the event is an invite for a local user then
+        automatically accepts it.
+
+        Args:
+            event: The incoming event.
+        """
+        # Check if the event is an invite for a local user.
+        is_invite_for_local_user = (
+            event.type == EventTypes.Member
+            and event.is_state()
+            and event.membership == Membership.INVITE
+            and self._api.is_mine(event.state_key)
+        )
+
+        # Only accept invites for direct messages if the configuration mandates it.
+        is_direct_message = event.content.get("is_direct", False)
+        is_allowed_by_direct_message_rules = (
+            not self._config.accept_invites_only_for_direct_messages
+            or is_direct_message is True
+        )
+
+        # Only accept invites from remote users if the configuration mandates it.
+        is_from_local_user = self._api.is_mine(event.sender)
+        is_allowed_by_local_user_rules = (
+            not self._config.accept_invites_only_from_local_users
+            or is_from_local_user is True
+        )
+
+        if (
+            is_invite_for_local_user
+            and is_allowed_by_direct_message_rules
+            and is_allowed_by_local_user_rules
+        ):
+            # Make the user join the room. We run this as a background process to circumvent a race condition
+            # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12)
+            run_as_background_process(
+                "retry_make_join",
+                self._retry_make_join,
+                event.state_key,
+                event.state_key,
+                event.room_id,
+                "join",
+                bg_start_span=False,
+            )
+
+            if is_direct_message:
+                # Mark this room as a direct message!
+                await self._mark_room_as_direct_message(
+                    event.state_key, event.sender, event.room_id
+                )
+
+    async def _mark_room_as_direct_message(
+        self, user_id: str, dm_user_id: str, room_id: str
+    ) -> None:
+        """
+        Marks a room (`room_id`) as a direct message with the counterparty `dm_user_id`
+        from the perspective of the user `user_id`.
+
+        Args:
+            user_id: the user for whom the membership is changing
+            dm_user_id: the user performing the membership change
+            room_id: room id of the room the user is invited to
+        """
+
+        # This is a dict of User IDs to tuples of Room IDs
+        # (get_global will return a frozendict of tuples as it freezes the data,
+        # but we should accept either frozen or unfrozen variants.)
+        # Be careful: we convert the outer frozendict into a dict here,
+        # but the contents of the dict are still frozen (tuples in lieu of lists,
+        # etc.)
+        dm_map: Dict[str, Tuple[str, ...]] = dict(
+            await self._api.account_data_manager.get_global(
+                user_id, AccountDataTypes.DIRECT
+            )
+            or {}
+        )
+
+        if dm_user_id not in dm_map:
+            dm_map[dm_user_id] = (room_id,)
+        else:
+            dm_rooms_for_user = dm_map[dm_user_id]
+            assert isinstance(dm_rooms_for_user, (tuple, list))
+
+            dm_map[dm_user_id] = tuple(dm_rooms_for_user) + (room_id,)
+
+        await self._api.account_data_manager.put_global(
+            user_id, AccountDataTypes.DIRECT, dm_map
+        )
+
+    async def _retry_make_join(
+        self, sender: str, target: str, room_id: str, new_membership: str
+    ) -> None:
+        """
+        A function to retry sending the `make_join` request with an increasing backoff. This is
+        implemented to work around a race condition when receiving invites over federation.
+
+        Args:
+            sender: the user performing the membership change
+            target: the user for whom the membership is changing
+            room_id: room id of the room to join to
+            new_membership: the type of membership event (in this case will be "join")
+        """
+
+        sleep = 0
+        retries = 0
+        join_event = None
+
+        while retries < 5:
+            try:
+                await self._api.sleep(sleep)
+                join_event = await self._api.update_room_membership(
+                    sender=sender,
+                    target=target,
+                    room_id=room_id,
+                    new_membership=new_membership,
+                )
+            except SynapseError as e:
+                if e.code == HTTPStatus.FORBIDDEN:
+                    logger.debug(
+                        f"Update_room_membership was forbidden. This can sometimes be expected for remote invites. Exception: {e}"
+                    )
+                else:
+                    logger.warn(
+                        f"Update_room_membership raised the following unexpected (SynapseError) exception: {e}"
+                    )
+            except Exception as e:
+                logger.warn(
+                    f"Update_room_membership raised the following unexpected exception: {e}"
+                )
+
+            sleep = 2**retries
+            retries += 1
+
+            if join_event is not None:
+                break
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index f275d4f35a..ee74289b6c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -817,7 +817,7 @@ class SsoHandler:
                 server_name = profile["avatar_url"].split("/")[-2]
                 media_id = profile["avatar_url"].split("/")[-1]
                 if self._is_mine_server_name(server_name):
-                    media = await self._media_repo.store.get_local_media(media_id)
+                    media = await self._media_repo.store.get_local_media(media_id)  # type: ignore[has-type]
                     if media is not None and upload_name == media.upload_name:
                         logger.info("skipping saving the user avatar")
                         return True
diff --git a/tests/events/test_auto_accept_invites.py b/tests/events/test_auto_accept_invites.py
new file mode 100644
index 0000000000..7fb4d4fa90
--- /dev/null
+++ b/tests/events/test_auto_accept_invites.py
@@ -0,0 +1,657 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2021 The Matrix.org Foundation C.I.C
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+import asyncio
+from asyncio import Future
+from http import HTTPStatus
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, TypeVar, cast
+from unittest.mock import Mock
+
+import attr
+from parameterized import parameterized
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
+from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
+from synapse.events.auto_accept_invites import InviteAutoAccepter
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.handlers.sync import JoinedSyncResult, SyncRequestKey, SyncVersion
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import StreamToken, create_requester
+from synapse.util import Clock
+
+from tests.handlers.test_sync import generate_sync_config
+from tests.unittest import (
+    FederatingHomeserverTestCase,
+    HomeserverTestCase,
+    TestCase,
+    override_config,
+)
+
+
+class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
+    """
+    Integration test cases for auto-accepting invites.
+    """
+
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        hs = self.setup_test_homeserver()
+        self.handler = hs.get_federation_handler()
+        self.store = hs.get_datastores().main
+        return hs
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.sync_handler = self.hs.get_sync_handler()
+        self.module_api = hs.get_module_api()
+
+    @parameterized.expand(
+        [
+            [False],
+            [True],
+        ]
+    )
+    @override_config(
+        {
+            "auto_accept_invites": {
+                "enabled": True,
+            },
+        }
+    )
+    def test_auto_accept_invites(self, direct_room: bool) -> None:
+        """Test that a user automatically joins a room when invited, if the
+        module is enabled.
+        """
+        # A local user who sends an invite
+        inviting_user_id = self.register_user("inviter", "pass")
+        inviting_user_tok = self.login("inviter", "pass")
+
+        # A local user who receives an invite
+        invited_user_id = self.register_user("invitee", "pass")
+        self.login("invitee", "pass")
+
+        # Create a room and send an invite to the other user
+        room_id = self.helper.create_room_as(
+            inviting_user_id,
+            is_public=False,
+            tok=inviting_user_tok,
+        )
+
+        self.helper.invite(
+            room_id,
+            inviting_user_id,
+            invited_user_id,
+            tok=inviting_user_tok,
+            extra_data={"is_direct": direct_room},
+        )
+
+        # Check that the invite receiving user has automatically joined the room when syncing
+        join_updates, _ = sync_join(self, invited_user_id)
+        self.assertEqual(len(join_updates), 1)
+
+        join_update: JoinedSyncResult = join_updates[0]
+        self.assertEqual(join_update.room_id, room_id)
+
+    @override_config(
+        {
+            "auto_accept_invites": {
+                "enabled": False,
+            },
+        }
+    )
+    def test_module_not_enabled(self) -> None:
+        """Test that a user does not automatically join a room when invited,
+        if the module is not enabled.
+        """
+        # A local user who sends an invite
+        inviting_user_id = self.register_user("inviter", "pass")
+        inviting_user_tok = self.login("inviter", "pass")
+
+        # A local user who receives an invite
+        invited_user_id = self.register_user("invitee", "pass")
+        self.login("invitee", "pass")
+
+        # Create a room and send an invite to the other user
+        room_id = self.helper.create_room_as(
+            inviting_user_id, is_public=False, tok=inviting_user_tok
+        )
+
+        self.helper.invite(
+            room_id,
+            inviting_user_id,
+            invited_user_id,
+            tok=inviting_user_tok,
+        )
+
+        # Check that the invite receiving user has not automatically joined the room when syncing
+        join_updates, _ = sync_join(self, invited_user_id)
+        self.assertEqual(len(join_updates), 0)
+
+    @override_config(
+        {
+            "auto_accept_invites": {
+                "enabled": True,
+            },
+        }
+    )
+    def test_invite_from_remote_user(self) -> None:
+        """Test that an invite from a remote user results in the invited user
+        automatically joining the room.
+        """
+        # A remote user who sends the invite
+        remote_server = "otherserver"
+        remote_user = "@otheruser:" + remote_server
+
+        # A local user who creates the room
+        creator_user_id = self.register_user("creator", "pass")
+        creator_user_tok = self.login("creator", "pass")
+
+        # A local user who receives an invite
+        invited_user_id = self.register_user("invitee", "pass")
+        self.login("invitee", "pass")
+
+        room_id = self.helper.create_room_as(
+            room_creator=creator_user_id, tok=creator_user_tok
+        )
+        room_version = self.get_success(self.store.get_room_version(room_id))
+
+        invite_event = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "content": {"membership": "invite"},
+                "room_id": room_id,
+                "sender": remote_user,
+                "state_key": invited_user_id,
+                "depth": 32,
+                "prev_events": [],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            room_version,
+        )
+        self.get_success(
+            self.handler.on_invite_request(
+                remote_server,
+                invite_event,
+                invite_event.room_version,
+            )
+        )
+
+        # Check that the invite receiving user has automatically joined the room when syncing
+        join_updates, _ = sync_join(self, invited_user_id)
+        self.assertEqual(len(join_updates), 1)
+
+        join_update: JoinedSyncResult = join_updates[0]
+        self.assertEqual(join_update.room_id, room_id)
+
+    @parameterized.expand(
+        [
+            [False, False],
+            [True, True],
+        ]
+    )
+    @override_config(
+        {
+            "auto_accept_invites": {
+                "enabled": True,
+                "only_for_direct_messages": True,
+            },
+        }
+    )
+    def test_accept_invite_direct_message(
+        self,
+        direct_room: bool,
+        expect_auto_join: bool,
+    ) -> None:
+        """Tests that, if the module is configured to only accept DM invites, invites to DM rooms are still
+        automatically accepted. Otherwise they are rejected.
+        """
+        # A local user who sends an invite
+        inviting_user_id = self.register_user("inviter", "pass")
+        inviting_user_tok = self.login("inviter", "pass")
+
+        # A local user who receives an invite
+        invited_user_id = self.register_user("invitee", "pass")
+        self.login("invitee", "pass")
+
+        # Create a room and send an invite to the other user
+        room_id = self.helper.create_room_as(
+            inviting_user_id,
+            is_public=False,
+            tok=inviting_user_tok,
+        )
+
+        self.helper.invite(
+            room_id,
+            inviting_user_id,
+            invited_user_id,
+            tok=inviting_user_tok,
+            extra_data={"is_direct": direct_room},
+        )
+
+        if expect_auto_join:
+            # Check that the invite receiving user has automatically joined the room when syncing
+            join_updates, _ = sync_join(self, invited_user_id)
+            self.assertEqual(len(join_updates), 1)
+
+            join_update: JoinedSyncResult = join_updates[0]
+            self.assertEqual(join_update.room_id, room_id)
+        else:
+            # Check that the invite receiving user has not automatically joined the room when syncing
+            join_updates, _ = sync_join(self, invited_user_id)
+            self.assertEqual(len(join_updates), 0)
+
+    @parameterized.expand(
+        [
+            [False, True],
+            [True, False],
+        ]
+    )
+    @override_config(
+        {
+            "auto_accept_invites": {
+                "enabled": True,
+                "only_from_local_users": True,
+            },
+        }
+    )
+    def test_accept_invite_local_user(
+        self, remote_inviter: bool, expect_auto_join: bool
+    ) -> None:
+        """Tests that, if the module is configured to only accept invites from local users, invites
+        from local users are still automatically accepted. Otherwise they are rejected.
+        """
+        # A local user who sends an invite
+        creator_user_id = self.register_user("inviter", "pass")
+        creator_user_tok = self.login("inviter", "pass")
+
+        # A local user who receives an invite
+        invited_user_id = self.register_user("invitee", "pass")
+        self.login("invitee", "pass")
+
+        # Create a room and send an invite to the other user
+        room_id = self.helper.create_room_as(
+            creator_user_id, is_public=False, tok=creator_user_tok
+        )
+
+        if remote_inviter:
+            room_version = self.get_success(self.store.get_room_version(room_id))
+
+            # A remote user who sends the invite
+            remote_server = "otherserver"
+            remote_user = "@otheruser:" + remote_server
+
+            invite_event = event_from_pdu_json(
+                {
+                    "type": EventTypes.Member,
+                    "content": {"membership": "invite"},
+                    "room_id": room_id,
+                    "sender": remote_user,
+                    "state_key": invited_user_id,
+                    "depth": 32,
+                    "prev_events": [],
+                    "auth_events": [],
+                    "origin_server_ts": self.clock.time_msec(),
+                },
+                room_version,
+            )
+            self.get_success(
+                self.handler.on_invite_request(
+                    remote_server,
+                    invite_event,
+                    invite_event.room_version,
+                )
+            )
+        else:
+            self.helper.invite(
+                room_id,
+                creator_user_id,
+                invited_user_id,
+                tok=creator_user_tok,
+            )
+
+        if expect_auto_join:
+            # Check that the invite receiving user has automatically joined the room when syncing
+            join_updates, _ = sync_join(self, invited_user_id)
+            self.assertEqual(len(join_updates), 1)
+
+            join_update: JoinedSyncResult = join_updates[0]
+            self.assertEqual(join_update.room_id, room_id)
+        else:
+            # Check that the invite receiving user has not automatically joined the room when syncing
+            join_updates, _ = sync_join(self, invited_user_id)
+            self.assertEqual(len(join_updates), 0)
+
+
+_request_key = 0
+
+
+def generate_request_key() -> SyncRequestKey:
+    global _request_key
+    _request_key += 1
+    return ("request_key", _request_key)
+
+
+def sync_join(
+    testcase: HomeserverTestCase,
+    user_id: str,
+    since_token: Optional[StreamToken] = None,
+) -> Tuple[List[JoinedSyncResult], StreamToken]:
+    """Perform a sync request for the given user and return the user join updates
+    they've received, as well as the next_batch token.
+
+    This method assumes testcase.sync_handler points to the homeserver's sync handler.
+
+    Args:
+        testcase: The testcase that is currently being run.
+        user_id: The ID of the user to generate a sync response for.
+        since_token: An optional token to indicate from at what point to sync from.
+
+    Returns:
+        A tuple containing a list of join updates, and the sync response's
+        next_batch token.
+    """
+    requester = create_requester(user_id)
+    sync_config = generate_sync_config(requester.user.to_string())
+    sync_result = testcase.get_success(
+        testcase.hs.get_sync_handler().wait_for_sync_for_user(
+            requester,
+            sync_config,
+            SyncVersion.SYNC_V2,
+            generate_request_key(),
+            since_token,
+        )
+    )
+
+    return sync_result.joined, sync_result.next_batch
+
+
+class InviteAutoAccepterInternalTestCase(TestCase):
+    """
+    Test cases which exercise the internals of the InviteAutoAccepter.
+    """
+
+    def setUp(self) -> None:
+        self.module = create_module()
+        self.user_id = "@peter:test"
+        self.invitee = "@lesley:test"
+        self.remote_invitee = "@thomas:remote"
+
+        # We know our module API is a mock, but mypy doesn't.
+        self.mocked_update_membership: Mock = self.module._api.update_room_membership  # type: ignore[assignment]
+
+    async def test_accept_invite_with_failures(self) -> None:
+        """Tests that receiving an invite for a local user makes the module attempt to
+        make the invitee join the room. This test verifies that it works if the call to
+        update membership returns exceptions before successfully completing and returning an event.
+        """
+        invite = MockEvent(
+            sender="@inviter:test",
+            state_key="@invitee:test",
+            type="m.room.member",
+            content={"membership": "invite"},
+        )
+
+        join_event = MockEvent(
+            sender="someone",
+            state_key="someone",
+            type="m.room.member",
+            content={"membership": "join"},
+        )
+        # the first two calls raise an exception while the third call is successful
+        self.mocked_update_membership.side_effect = [
+            SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
+            SynapseError(HTTPStatus.FORBIDDEN, "Forbidden"),
+            make_awaitable(join_event),
+        ]
+
+        # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+        # EventBase.
+        await self.module.on_new_event(event=invite)  # type: ignore[arg-type]
+
+        await self.retry_assertions(
+            self.mocked_update_membership,
+            3,
+            sender=invite.state_key,
+            target=invite.state_key,
+            room_id=invite.room_id,
+            new_membership="join",
+        )
+
+    async def test_accept_invite_failures(self) -> None:
+        """Tests that receiving an invite for a local user makes the module attempt to
+        make the invitee join the room. This test verifies that if the update_membership call
+        fails consistently, _retry_make_join will break the loop after the set number of retries and
+        execution will continue.
+        """
+        invite = MockEvent(
+            sender=self.user_id,
+            state_key=self.invitee,
+            type="m.room.member",
+            content={"membership": "invite"},
+        )
+        self.mocked_update_membership.side_effect = SynapseError(
+            HTTPStatus.FORBIDDEN, "Forbidden"
+        )
+
+        # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+        # EventBase.
+        await self.module.on_new_event(event=invite)  # type: ignore[arg-type]
+
+        await self.retry_assertions(
+            self.mocked_update_membership,
+            5,
+            sender=invite.state_key,
+            target=invite.state_key,
+            room_id=invite.room_id,
+            new_membership="join",
+        )
+
+    async def test_not_state(self) -> None:
+        """Tests that receiving an invite that's not a state event does nothing."""
+        invite = MockEvent(
+            sender=self.user_id, type="m.room.member", content={"membership": "invite"}
+        )
+
+        # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+        # EventBase.
+        await self.module.on_new_event(event=invite)  # type: ignore[arg-type]
+
+        self.mocked_update_membership.assert_not_called()
+
+    async def test_not_invite(self) -> None:
+        """Tests that receiving a membership update that's not an invite does nothing."""
+        invite = MockEvent(
+            sender=self.user_id,
+            state_key=self.user_id,
+            type="m.room.member",
+            content={"membership": "join"},
+        )
+
+        # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+        # EventBase.
+        await self.module.on_new_event(event=invite)  # type: ignore[arg-type]
+
+        self.mocked_update_membership.assert_not_called()
+
+    async def test_not_membership(self) -> None:
+        """Tests that receiving a state event that's not a membership update does
+        nothing.
+        """
+        invite = MockEvent(
+            sender=self.user_id,
+            state_key=self.user_id,
+            type="org.matrix.test",
+            content={"foo": "bar"},
+        )
+
+        # Stop mypy from complaining that we give on_new_event a MockEvent rather than an
+        # EventBase.
+        await self.module.on_new_event(event=invite)  # type: ignore[arg-type]
+
+        self.mocked_update_membership.assert_not_called()
+
+    def test_config_parse(self) -> None:
+        """Tests that a correct configuration parses."""
+        config = {
+            "auto_accept_invites": {
+                "enabled": True,
+                "only_for_direct_messages": True,
+                "only_from_local_users": True,
+            }
+        }
+        parsed_config = AutoAcceptInvitesConfig()
+        parsed_config.read_config(config)
+
+        self.assertTrue(parsed_config.enabled)
+        self.assertTrue(parsed_config.accept_invites_only_for_direct_messages)
+        self.assertTrue(parsed_config.accept_invites_only_from_local_users)
+
+    def test_runs_on_only_one_worker(self) -> None:
+        """
+        Tests that the module only runs on the specified worker.
+        """
+        # By default, we run on the main process...
+        main_module = create_module(
+            config_override={"auto_accept_invites": {"enabled": True}}, worker_name=None
+        )
+        cast(
+            Mock, main_module._api.register_third_party_rules_callbacks
+        ).assert_called_once()
+
+        # ...and not on other workers (like synchrotrons)...
+        sync_module = create_module(worker_name="synchrotron42")
+        cast(
+            Mock, sync_module._api.register_third_party_rules_callbacks
+        ).assert_not_called()
+
+        # ...unless we configured them to be the designated worker.
+        specified_module = create_module(
+            config_override={
+                "auto_accept_invites": {
+                    "enabled": True,
+                    "worker_to_run_on": "account_data1",
+                }
+            },
+            worker_name="account_data1",
+        )
+        cast(
+            Mock, specified_module._api.register_third_party_rules_callbacks
+        ).assert_called_once()
+
+    async def retry_assertions(
+        self, mock: Mock, call_count: int, **kwargs: Any
+    ) -> None:
+        """
+        This is a hacky way to ensure that the assertions are not called before the other coroutine
+        has a chance to call `update_room_membership`. It catches the exception caused by a failure,
+        and sleeps the thread before retrying, up until 5 tries.
+
+        Args:
+            call_count: the number of times the mock should have been called
+            mock: the mocked function we want to assert on
+            kwargs: keyword arguments to assert that the mock was called with
+        """
+
+        i = 0
+        while i < 5:
+            try:
+                # Check that the mocked method is called the expected amount of times and with the right
+                # arguments to attempt to make the user join the room.
+                mock.assert_called_with(**kwargs)
+                self.assertEqual(call_count, mock.call_count)
+                break
+            except AssertionError as e:
+                i += 1
+                if i == 5:
+                    # we've used up the tries, force the test to fail as we've already caught the exception
+                    self.fail(e)
+                await asyncio.sleep(1)
+
+
+@attr.s(auto_attribs=True)
+class MockEvent:
+    """Mocks an event. Only exposes properties the module uses."""
+
+    sender: str
+    type: str
+    content: Dict[str, Any]
+    room_id: str = "!someroom"
+    state_key: Optional[str] = None
+
+    def is_state(self) -> bool:
+        """Checks if the event is a state event by checking if it has a state key."""
+        return self.state_key is not None
+
+    @property
+    def membership(self) -> str:
+        """Extracts the membership from the event. Should only be called on an event
+        that's a membership event, and will raise a KeyError otherwise.
+        """
+        membership: str = self.content["membership"]
+        return membership
+
+
+T = TypeVar("T")
+TV = TypeVar("TV")
+
+
+async def make_awaitable(value: T) -> T:
+    return value
+
+
+def make_multiple_awaitable(result: TV) -> Awaitable[TV]:
+    """
+    Makes an awaitable, suitable for mocking an `async` function.
+    This uses Futures as they can be awaited multiple times so can be returned
+    to multiple callers.
+    """
+    future: Future[TV] = Future()
+    future.set_result(result)
+    return future
+
+
+def create_module(
+    config_override: Optional[Dict[str, Any]] = None, worker_name: Optional[str] = None
+) -> InviteAutoAccepter:
+    # Create a mock based on the ModuleApi spec, but override some mocked functions
+    # because some capabilities are needed for running the tests.
+    module_api = Mock(spec=ModuleApi)
+    module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test"
+    module_api.worker_name = worker_name
+    module_api.sleep.return_value = make_multiple_awaitable(None)
+
+    if config_override is None:
+        config_override = {}
+
+    config = AutoAcceptInvitesConfig()
+    config.read_config(config_override)
+
+    return InviteAutoAccepter(config, module_api)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index fe00afe198..7362bde7ab 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -170,6 +170,7 @@ class RestHelper:
         targ: Optional[str] = None,
         expect_code: int = HTTPStatus.OK,
         tok: Optional[str] = None,
+        extra_data: Optional[dict] = None,
     ) -> JsonDict:
         return self.change_membership(
             room=room,
@@ -178,6 +179,7 @@ class RestHelper:
             tok=tok,
             membership=Membership.INVITE,
             expect_code=expect_code,
+            extra_data=extra_data,
         )
 
     def join(
diff --git a/tests/server.py b/tests/server.py
index 434be3d22c..f3a917f835 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -85,6 +85,7 @@ from twisted.web.server import Request, Site
 
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
+from synapse.events.auto_accept_invites import InviteAutoAccepter
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.http.site import SynapseRequest
@@ -1156,6 +1157,11 @@ def setup_test_homeserver(
     for module, module_config in hs.config.modules.loaded_modules:
         module(config=module_config, api=module_api)
 
+    if hs.config.auto_accept_invites.enabled:
+        # Start the local auto_accept_invites module.
+        m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api)
+        logger.info("Loaded local module %s", m)
+
     load_legacy_spam_checkers(hs)
     load_legacy_third_party_event_rules(hs)
     load_legacy_presence_router(hs)