summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2023-03-09 10:18:42 +0000
committerAndrew Morgan <andrew@amorgan.xyz>2023-03-09 16:50:31 +0000
commit1b30b82ac694e62d6346d3510b3e31f3fe03071b (patch)
treed9fc0fbc90b775326347c2d5603be001b5a04b27
parentMove callback-related code from ThirdPartyEventRules to its own class (diff)
downloadsynapse-1b30b82ac694e62d6346d3510b3e31f3fe03071b.tar.xz
Move callback-related code from the PresenceRouter to its own class
-rw-r--r--synapse/app/_base.py4
-rw-r--r--synapse/events/presence_router.py111
-rw-r--r--synapse/module_api/__init__.py13
-rw-r--r--synapse/module_api/callbacks/__init__.py2
-rw-r--r--synapse/module_api/callbacks/presence_router_callbacks.py122
-rw-r--r--tests/events/test_presence_router.py5
-rw-r--r--tests/server.py4
7 files changed, 147 insertions, 114 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 7ac23bef86..950b8b7389 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -58,7 +58,6 @@ from synapse.config._base import format_config_error
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig, ManholeConfig
 from synapse.crypto import context_factory
-from synapse.events.presence_router import load_legacy_presence_router
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.http.site import SynapseSite
 from synapse.logging.context import PreserveLoggingContext
@@ -66,6 +65,9 @@ from synapse.logging.opentracing import init_tracer
 from synapse.metrics import install_gc_manager, register_threadpool
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
+from synapse.module_api.callbacks.presence_router_callbacks import (
+    load_legacy_presence_router,
+)
 from synapse.module_api.callbacks.spam_checker_callbacks import (
     load_legacy_spam_checkers,
 )
diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py
index bb4a6bd957..1f1d89fd7e 100644
--- a/synapse/events/presence_router.py
+++ b/synapse/events/presence_router.py
@@ -12,93 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Awaitable,
-    Callable,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Set,
-    TypeVar,
-    Union,
-)
-
-from typing_extensions import ParamSpec
+from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
 
 from twisted.internet.defer import CancelledError
 
 from synapse.api.presence import UserPresenceState
-from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
+from synapse.util.async_helpers import delay_cancellation
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
-GET_USERS_FOR_STATES_CALLBACK = Callable[
-    [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
-]
-# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
-GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
-
 logger = logging.getLogger(__name__)
 
 
-P = ParamSpec("P")
-R = TypeVar("R")
-
-
-def load_legacy_presence_router(hs: "HomeServer") -> None:
-    """Wrapper that loads a presence router module configured using the old
-    configuration, and registers the hooks they implement.
-    """
-
-    if hs.config.server.presence_router_module_class is None:
-        return
-
-    module = hs.config.server.presence_router_module_class
-    config = hs.config.server.presence_router_config
-    api = hs.get_module_api()
-
-    presence_router = module(config=config, module_api=api)
-
-    # The known hooks. If a module implements a method which name appears in this set,
-    # we'll want to register it.
-    presence_router_methods = {
-        "get_users_for_states",
-        "get_interested_users",
-    }
-
-    # All methods that the module provides should be async, but this wasn't enforced
-    # in the old module system, so we wrap them if needed
-    def async_wrapper(
-        f: Optional[Callable[P, R]]
-    ) -> Optional[Callable[P, Awaitable[R]]]:
-        # f might be None if the callback isn't implemented by the module. In this
-        # case we don't want to register a callback at all so we return None.
-        if f is None:
-            return None
-
-        def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
-            # Assertion required because mypy can't prove we won't change `f`
-            # back to `None`. See
-            # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
-            assert f is not None
-
-            return maybe_awaitable(f(*args, **kwargs))
-
-        return run
-
-    # Register the hooks through the module API.
-    hooks: Dict[str, Optional[Callable[..., Any]]] = {
-        hook: async_wrapper(getattr(presence_router, hook, None))
-        for hook in presence_router_methods
-    }
-
-    api.register_presence_router_callbacks(**hooks)
-
-
 class PresenceRouter:
     """
     A module that the homeserver will call upon to help route user presence updates to
@@ -108,30 +34,7 @@ class PresenceRouter:
     ALL_USERS = "ALL"
 
     def __init__(self, hs: "HomeServer"):
-        # Initially there are no callbacks
-        self._get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = []
-        self._get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = []
-
-    def register_presence_router_callbacks(
-        self,
-        get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
-        get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
-    ) -> None:
-        # PresenceRouter modules are required to implement both of these methods
-        # or neither of them as they are assumed to act in a complementary manner
-        paired_methods = [get_users_for_states, get_interested_users]
-        if paired_methods.count(None) == 1:
-            raise RuntimeError(
-                "PresenceRouter modules must register neither or both of the paired callbacks: "
-                "[get_users_for_states, get_interested_users]"
-            )
-
-        # Append the methods provided to the lists of callbacks
-        if get_users_for_states is not None:
-            self._get_users_for_states_callbacks.append(get_users_for_states)
-
-        if get_interested_users is not None:
-            self._get_interested_users_callbacks.append(get_interested_users)
+        self._module_api_callbacks = hs.get_module_api_callbacks().presence_router
 
     async def get_users_for_states(
         self,
@@ -150,13 +53,13 @@ class PresenceRouter:
         """
 
         # Bail out early if we don't have any callbacks to run.
-        if len(self._get_users_for_states_callbacks) == 0:
+        if len(self._module_api_callbacks.get_users_for_states_callbacks) == 0:
             # Don't include any extra destinations for presence updates
             return {}
 
         users_for_states: Dict[str, Set[UserPresenceState]] = {}
         # run all the callbacks for get_users_for_states and combine the results
-        for callback in self._get_users_for_states_callbacks:
+        for callback in self._module_api_callbacks.get_users_for_states_callbacks:
             try:
                 # Note: result is an object here, because we don't trust modules to
                 # return the types they're supposed to.
@@ -206,13 +109,13 @@ class PresenceRouter:
         """
 
         # Bail out early if we don't have any callbacks to run.
-        if len(self._get_interested_users_callbacks) == 0:
+        if len(self._module_api_callbacks.get_interested_users_callbacks) == 0:
             # Don't report any additional interested users
             return set()
 
         interested_users = set()
         # run all the callbacks for get_interested_users and combine the results
-        for callback in self._get_interested_users_callbacks:
+        for callback in self._module_api_callbacks.get_interested_users_callbacks:
             try:
                 result = await delay_cancellation(callback(user_id))
             except CancelledError:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2dff8c457e..7a25a47ac5 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -39,11 +39,7 @@ from twisted.web.resource import Resource
 from synapse.api import errors
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
-from synapse.events.presence_router import (
-    GET_INTERESTED_USERS_CALLBACK,
-    GET_USERS_FOR_STATES_CALLBACK,
-    PresenceRouter,
-)
+from synapse.events.presence_router import PresenceRouter
 from synapse.events.spamcheck import SpamChecker
 from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK
 from synapse.handlers.auth import (
@@ -78,6 +74,10 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
     ON_LEGACY_SEND_MAIL_CALLBACK,
     ON_USER_REGISTRATION_CALLBACK,
 )
+from synapse.module_api.callbacks.presence_router_callbacks import (
+    GET_INTERESTED_USERS_CALLBACK,
+    GET_USERS_FOR_STATES_CALLBACK,
+)
 from synapse.module_api.callbacks.spam_checker_callbacks import (
     CHECK_EVENT_FOR_SPAM_CALLBACK,
     CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
@@ -272,7 +272,6 @@ class ModuleApi:
         self._account_data_manager = AccountDataManager(hs)
 
         self._password_auth_provider = hs.get_password_auth_provider()
-        self._presence_router = hs.get_presence_router()
         self._account_data_handler = hs.get_account_data_handler()
 
     #################################################################################
@@ -393,7 +392,7 @@ class ModuleApi:
 
         Added in Synapse v1.42.0.
         """
-        return self._presence_router.register_presence_router_callbacks(
+        return self._callbacks.presence_router.register_callbacks(
             get_users_for_states=get_users_for_states,
             get_interested_users=get_interested_users,
         )
diff --git a/synapse/module_api/callbacks/__init__.py b/synapse/module_api/callbacks/__init__.py
index 6d17aef9f8..50b5f2f4d7 100644
--- a/synapse/module_api/callbacks/__init__.py
+++ b/synapse/module_api/callbacks/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from .account_validity_callbacks import AccountValidityModuleApiCallbacks
+from .presence_router_callbacks import PresenceRouterModuleApiCallbacks
 from .spam_checker_callbacks import SpamCheckerModuleApiCallbacks
 from .third_party_event_rules_callbacks import ThirdPartyEventRulesModuleApiCallbacks
 
@@ -26,3 +27,4 @@ class ModuleApiCallbacks:
         self.account_validity = AccountValidityModuleApiCallbacks()
         self.spam_checker = SpamCheckerModuleApiCallbacks()
         self.third_party_event_rules = ThirdPartyEventRulesModuleApiCallbacks()
+        self.presence_router = PresenceRouterModuleApiCallbacks()
diff --git a/synapse/module_api/callbacks/presence_router_callbacks.py b/synapse/module_api/callbacks/presence_router_callbacks.py
new file mode 100644
index 0000000000..5eb7f2fb69
--- /dev/null
+++ b/synapse/module_api/callbacks/presence_router_callbacks.py
@@ -0,0 +1,122 @@
+# Copyright 2021, 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    TypeVar,
+    Union,
+)
+
+from typing_extensions import ParamSpec
+
+from synapse.api.presence import UserPresenceState
+from synapse.util.async_helpers import maybe_awaitable
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+GET_USERS_FOR_STATES_CALLBACK = Callable[
+    [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
+]
+# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
+GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
+
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def load_legacy_presence_router(hs: "HomeServer") -> None:
+    """Wrapper that loads a presence router module configured using the old
+    configuration, and registers the hooks they implement.
+    """
+
+    if hs.config.server.presence_router_module_class is None:
+        return
+
+    module = hs.config.server.presence_router_module_class
+    config = hs.config.server.presence_router_config
+    api = hs.get_module_api()
+
+    presence_router = module(config=config, module_api=api)
+
+    # The known hooks. If a module implements a method which name appears in this set,
+    # we'll want to register it.
+    presence_router_methods = {
+        "get_users_for_states",
+        "get_interested_users",
+    }
+
+    # All methods that the module provides should be async, but this wasn't enforced
+    # in the old module system, so we wrap them if needed
+    def async_wrapper(
+        f: Optional[Callable[P, R]]
+    ) -> Optional[Callable[P, Awaitable[R]]]:
+        # f might be None if the callback isn't implemented by the module. In this
+        # case we don't want to register a callback at all so we return None.
+        if f is None:
+            return None
+
+        def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
+            # Assertion required because mypy can't prove we won't change `f`
+            # back to `None`. See
+            # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
+            assert f is not None
+
+            return maybe_awaitable(f(*args, **kwargs))
+
+        return run
+
+    # Register the hooks through the module API.
+    hooks: Dict[str, Optional[Callable[..., Any]]] = {
+        hook: async_wrapper(getattr(presence_router, hook, None))
+        for hook in presence_router_methods
+    }
+
+    api.register_presence_router_callbacks(**hooks)
+
+
+class PresenceRouterModuleApiCallbacks:
+    def __init__(self) -> None:
+        # Initially there are no callbacks
+        self.get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = []
+        self.get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = []
+
+    def register_callbacks(
+        self,
+        get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
+        get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
+    ) -> None:
+        # PresenceRouter modules are required to implement both of these methods
+        # or neither of them as they are assumed to act in a complementary manner
+        paired_methods = [get_users_for_states, get_interested_users]
+        if paired_methods.count(None) == 1:
+            raise RuntimeError(
+                "PresenceRouter modules must register neither or both of the paired callbacks: "
+                "[get_users_for_states, get_interested_users]"
+            )
+
+        # Append the methods provided to the lists of callbacks
+        if get_users_for_states is not None:
+            self.get_users_for_states_callbacks.append(get_users_for_states)
+
+        if get_interested_users is not None:
+            self.get_interested_users_callbacks.append(get_interested_users)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 6fb1f1bd6e..bd6f87faba 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -19,10 +19,13 @@ import attr
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EduTypes
-from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
+from synapse.events.presence_router import PresenceRouter
 from synapse.federation.units import Transaction
 from synapse.handlers.presence import UserPresenceState
 from synapse.module_api import ModuleApi
+from synapse.module_api.callbacks.presence_router_callbacks import (
+    load_legacy_presence_router,
+)
 from synapse.rest import admin
 from synapse.rest.client import login, presence, room
 from synapse.server import HomeServer
diff --git a/tests/server.py b/tests/server.py
index 512dda55d7..787f558069 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -71,10 +71,12 @@ from twisted.web.server import Request, Site
 
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
-from synapse.events.presence_router import load_legacy_presence_router
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import ContextResourceUsage
+from synapse.module_api.callbacks.presence_router_callbacks import (
+    load_legacy_presence_router,
+)
 from synapse.module_api.callbacks.spam_checker_callbacks import (
     load_legacy_spam_checkers,
 )