summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/_base.py6
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/auto_accept_invites.py43
-rw-r--r--synapse/config/homeserver.py2
-rw-r--r--synapse/events/auto_accept_invites.py196
-rw-r--r--synapse/handlers/sso.py2
6 files changed, 250 insertions, 1 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 3182608f73..67e0df1459 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -68,6 +68,7 @@ from synapse.config._base import format_config_error
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig, ManholeConfig, TCPListenerConfig
 from synapse.crypto import context_factory
+from synapse.events.auto_accept_invites import InviteAutoAccepter
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.http.site import SynapseSite
@@ -582,6 +583,11 @@ async def start(hs: "HomeServer") -> None:
         m = module(config, module_api)
         logger.info("Loaded module %s", m)
 
+    if hs.config.auto_accept_invites.enabled:
+        # Start the local auto_accept_invites module.
+        m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api)
+        logger.info("Loaded local module %s", m)
+
     load_legacy_spam_checkers(hs)
     load_legacy_third_party_event_rules(hs)
     load_legacy_presence_router(hs)
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index fc51aed234..d9cb0da38b 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -23,6 +23,7 @@ from synapse.config import (  # noqa: F401
     api,
     appservice,
     auth,
+    auto_accept_invites,
     background_updates,
     cache,
     captcha,
@@ -120,6 +121,7 @@ class RootConfig:
     federation: federation.FederationConfig
     retention: retention.RetentionConfig
     background_updates: background_updates.BackgroundUpdateConfig
+    auto_accept_invites: auto_accept_invites.AutoAcceptInvitesConfig
 
     config_classes: List[Type["Config"]] = ...
     config_files: List[str]
diff --git a/synapse/config/auto_accept_invites.py b/synapse/config/auto_accept_invites.py
new file mode 100644
index 0000000000..d90e13a510
--- /dev/null
+++ b/synapse/config/auto_accept_invites.py
@@ -0,0 +1,43 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+from typing import Any
+
+from synapse.types import JsonDict
+
+from ._base import Config
+
+
+class AutoAcceptInvitesConfig(Config):
+    section = "auto_accept_invites"
+
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
+        auto_accept_invites_config = config.get("auto_accept_invites") or {}
+
+        self.enabled = auto_accept_invites_config.get("enabled", False)
+
+        self.accept_invites_only_for_direct_messages = auto_accept_invites_config.get(
+            "only_for_direct_messages", False
+        )
+
+        self.accept_invites_only_from_local_users = auto_accept_invites_config.get(
+            "only_from_local_users", False
+        )
+
+        self.worker_to_run_on = auto_accept_invites_config.get("worker_to_run_on")
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 72e93ed04f..e36c0bd6ae 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -23,6 +23,7 @@ from .account_validity import AccountValidityConfig
 from .api import ApiConfig
 from .appservice import AppServiceConfig
 from .auth import AuthConfig
+from .auto_accept_invites import AutoAcceptInvitesConfig
 from .background_updates import BackgroundUpdateConfig
 from .cache import CacheConfig
 from .captcha import CaptchaConfig
@@ -105,4 +106,5 @@ class HomeServerConfig(RootConfig):
         RedisConfig,
         ExperimentalConfig,
         BackgroundUpdateConfig,
+        AutoAcceptInvitesConfig,
     ]
diff --git a/synapse/events/auto_accept_invites.py b/synapse/events/auto_accept_invites.py
new file mode 100644
index 0000000000..d88ec51d9d
--- /dev/null
+++ b/synapse/events/auto_accept_invites.py
@@ -0,0 +1,196 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2021 The Matrix.org Foundation C.I.C
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+import logging
+from http import HTTPStatus
+from typing import Any, Dict, Tuple
+
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.errors import SynapseError
+from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
+from synapse.module_api import EventBase, ModuleApi, run_as_background_process
+
+logger = logging.getLogger(__name__)
+
+
+class InviteAutoAccepter:
+    def __init__(self, config: AutoAcceptInvitesConfig, api: ModuleApi):
+        # Keep a reference to the Module API.
+        self._api = api
+        self._config = config
+
+        if not self._config.enabled:
+            return
+
+        should_run_on_this_worker = config.worker_to_run_on == self._api.worker_name
+
+        if not should_run_on_this_worker:
+            logger.info(
+                "Not accepting invites on this worker (configured: %r, here: %r)",
+                config.worker_to_run_on,
+                self._api.worker_name,
+            )
+            return
+
+        logger.info(
+            "Accepting invites on this worker (here: %r)", self._api.worker_name
+        )
+
+        # Register the callback.
+        self._api.register_third_party_rules_callbacks(
+            on_new_event=self.on_new_event,
+        )
+
+    async def on_new_event(self, event: EventBase, *args: Any) -> None:
+        """Listens for new events, and if the event is an invite for a local user then
+        automatically accepts it.
+
+        Args:
+            event: The incoming event.
+        """
+        # Check if the event is an invite for a local user.
+        is_invite_for_local_user = (
+            event.type == EventTypes.Member
+            and event.is_state()
+            and event.membership == Membership.INVITE
+            and self._api.is_mine(event.state_key)
+        )
+
+        # Only accept invites for direct messages if the configuration mandates it.
+        is_direct_message = event.content.get("is_direct", False)
+        is_allowed_by_direct_message_rules = (
+            not self._config.accept_invites_only_for_direct_messages
+            or is_direct_message is True
+        )
+
+        # Only accept invites from remote users if the configuration mandates it.
+        is_from_local_user = self._api.is_mine(event.sender)
+        is_allowed_by_local_user_rules = (
+            not self._config.accept_invites_only_from_local_users
+            or is_from_local_user is True
+        )
+
+        if (
+            is_invite_for_local_user
+            and is_allowed_by_direct_message_rules
+            and is_allowed_by_local_user_rules
+        ):
+            # Make the user join the room. We run this as a background process to circumvent a race condition
+            # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12)
+            run_as_background_process(
+                "retry_make_join",
+                self._retry_make_join,
+                event.state_key,
+                event.state_key,
+                event.room_id,
+                "join",
+                bg_start_span=False,
+            )
+
+            if is_direct_message:
+                # Mark this room as a direct message!
+                await self._mark_room_as_direct_message(
+                    event.state_key, event.sender, event.room_id
+                )
+
+    async def _mark_room_as_direct_message(
+        self, user_id: str, dm_user_id: str, room_id: str
+    ) -> None:
+        """
+        Marks a room (`room_id`) as a direct message with the counterparty `dm_user_id`
+        from the perspective of the user `user_id`.
+
+        Args:
+            user_id: the user for whom the membership is changing
+            dm_user_id: the user performing the membership change
+            room_id: room id of the room the user is invited to
+        """
+
+        # This is a dict of User IDs to tuples of Room IDs
+        # (get_global will return a frozendict of tuples as it freezes the data,
+        # but we should accept either frozen or unfrozen variants.)
+        # Be careful: we convert the outer frozendict into a dict here,
+        # but the contents of the dict are still frozen (tuples in lieu of lists,
+        # etc.)
+        dm_map: Dict[str, Tuple[str, ...]] = dict(
+            await self._api.account_data_manager.get_global(
+                user_id, AccountDataTypes.DIRECT
+            )
+            or {}
+        )
+
+        if dm_user_id not in dm_map:
+            dm_map[dm_user_id] = (room_id,)
+        else:
+            dm_rooms_for_user = dm_map[dm_user_id]
+            assert isinstance(dm_rooms_for_user, (tuple, list))
+
+            dm_map[dm_user_id] = tuple(dm_rooms_for_user) + (room_id,)
+
+        await self._api.account_data_manager.put_global(
+            user_id, AccountDataTypes.DIRECT, dm_map
+        )
+
+    async def _retry_make_join(
+        self, sender: str, target: str, room_id: str, new_membership: str
+    ) -> None:
+        """
+        A function to retry sending the `make_join` request with an increasing backoff. This is
+        implemented to work around a race condition when receiving invites over federation.
+
+        Args:
+            sender: the user performing the membership change
+            target: the user for whom the membership is changing
+            room_id: room id of the room to join to
+            new_membership: the type of membership event (in this case will be "join")
+        """
+
+        sleep = 0
+        retries = 0
+        join_event = None
+
+        while retries < 5:
+            try:
+                await self._api.sleep(sleep)
+                join_event = await self._api.update_room_membership(
+                    sender=sender,
+                    target=target,
+                    room_id=room_id,
+                    new_membership=new_membership,
+                )
+            except SynapseError as e:
+                if e.code == HTTPStatus.FORBIDDEN:
+                    logger.debug(
+                        f"Update_room_membership was forbidden. This can sometimes be expected for remote invites. Exception: {e}"
+                    )
+                else:
+                    logger.warn(
+                        f"Update_room_membership raised the following unexpected (SynapseError) exception: {e}"
+                    )
+            except Exception as e:
+                logger.warn(
+                    f"Update_room_membership raised the following unexpected exception: {e}"
+                )
+
+            sleep = 2**retries
+            retries += 1
+
+            if join_event is not None:
+                break
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index f275d4f35a..ee74289b6c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -817,7 +817,7 @@ class SsoHandler:
                 server_name = profile["avatar_url"].split("/")[-2]
                 media_id = profile["avatar_url"].split("/")[-1]
                 if self._is_mine_server_name(server_name):
-                    media = await self._media_repo.store.get_local_media(media_id)
+                    media = await self._media_repo.store.get_local_media(media_id)  # type: ignore[has-type]
                     if media is not None and upload_name == media.upload_name:
                         logger.info("skipping saving the user avatar")
                         return True