summary refs log tree commit diff
path: root/synapse/events/presence_router.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events/presence_router.py')
-rw-r--r--synapse/events/presence_router.py192
1 files changed, 156 insertions, 36 deletions
diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py
index 6c37c8a7a4..eb4556cdc1 100644
--- a/synapse/events/presence_router.py
+++ b/synapse/events/presence_router.py
@@ -11,45 +11,115 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
+import logging
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Union,
+)
 
 from synapse.api.presence import UserPresenceState
+from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+GET_USERS_FOR_STATES_CALLBACK = Callable[
+    [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
+]
+GET_INTERESTED_USERS_CALLBACK = Callable[
+    [str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
+]
+
+logger = logging.getLogger(__name__)
+
+
+def load_legacy_presence_router(hs: "HomeServer"):
+    """Wrapper that loads a presence router module configured using the old
+    configuration, and registers the hooks they implement.
+    """
+
+    if hs.config.presence_router_module_class is None:
+        return
+
+    module = hs.config.presence_router_module_class
+    config = hs.config.presence_router_config
+    api = hs.get_module_api()
+
+    presence_router = module(config=config, module_api=api)
+
+    # The known hooks. If a module implements a method which name appears in this set,
+    # we'll want to register it.
+    presence_router_methods = {
+        "get_users_for_states",
+        "get_interested_users",
+    }
+
+    # All methods that the module provides should be async, but this wasn't enforced
+    # in the old module system, so we wrap them if needed
+    def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+        # f might be None if the callback isn't implemented by the module. In this
+        # case we don't want to register a callback at all so we return None.
+        if f is None:
+            return None
+
+        def run(*args, **kwargs):
+            # mypy doesn't do well across function boundaries so we need to tell it
+            # f is definitely not None.
+            assert f is not None
+
+            return maybe_awaitable(f(*args, **kwargs))
+
+        return run
+
+    # Register the hooks through the module API.
+    hooks = {
+        hook: async_wrapper(getattr(presence_router, hook, None))
+        for hook in presence_router_methods
+    }
+
+    api.register_presence_router_callbacks(**hooks)
+
 
 class PresenceRouter:
     """
     A module that the homeserver will call upon to help route user presence updates to
-    additional destinations. If a custom presence router is configured, calls will be
-    passed to that instead.
+    additional destinations.
     """
 
     ALL_USERS = "ALL"
 
     def __init__(self, hs: "HomeServer"):
-        self.custom_presence_router = None
+        # Initially there are no callbacks
+        self._get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = []
+        self._get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = []
 
-        # Check whether a custom presence router module has been configured
-        if hs.config.presence_router_module_class:
-            # Initialise the module
-            self.custom_presence_router = hs.config.presence_router_module_class(
-                config=hs.config.presence_router_config, module_api=hs.get_module_api()
+    def register_presence_router_callbacks(
+        self,
+        get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
+        get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
+    ):
+        # PresenceRouter modules are required to implement both of these methods
+        # or neither of them as they are assumed to act in a complementary manner
+        paired_methods = [get_users_for_states, get_interested_users]
+        if paired_methods.count(None) == 1:
+            raise RuntimeError(
+                "PresenceRouter modules must register neither or both of the paired callbacks: "
+                "[get_users_for_states, get_interested_users]"
             )
 
-            # Ensure the module has implemented the required methods
-            required_methods = ["get_users_for_states", "get_interested_users"]
-            for method_name in required_methods:
-                if not hasattr(self.custom_presence_router, method_name):
-                    raise Exception(
-                        "PresenceRouter module '%s' must implement all required methods: %s"
-                        % (
-                            hs.config.presence_router_module_class.__name__,
-                            ", ".join(required_methods),
-                        )
-                    )
+        # Append the methods provided to the lists of callbacks
+        if get_users_for_states is not None:
+            self._get_users_for_states_callbacks.append(get_users_for_states)
+
+        if get_interested_users is not None:
+            self._get_interested_users_callbacks.append(get_interested_users)
 
     async def get_users_for_states(
         self,
@@ -66,14 +136,40 @@ class PresenceRouter:
           A dictionary of user_id -> set of UserPresenceState, indicating which
           presence updates each user should receive.
         """
-        if self.custom_presence_router is not None:
-            # Ask the custom module
-            return await self.custom_presence_router.get_users_for_states(
-                state_updates=state_updates
-            )
 
-        # Don't include any extra destinations for presence updates
-        return {}
+        # Bail out early if we don't have any callbacks to run.
+        if len(self._get_users_for_states_callbacks) == 0:
+            # Don't include any extra destinations for presence updates
+            return {}
+
+        users_for_states = {}
+        # run all the callbacks for get_users_for_states and combine the results
+        for callback in self._get_users_for_states_callbacks:
+            try:
+                result = await callback(state_updates)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
+
+            if not isinstance(result, Dict):
+                logger.warning(
+                    "Wrong type returned by module API callback %s: %s, expected Dict",
+                    callback,
+                    result,
+                )
+                continue
+
+            for key, new_entries in result.items():
+                if not isinstance(new_entries, Set):
+                    logger.warning(
+                        "Wrong type returned by module API callback %s: %s, expected Set",
+                        callback,
+                        new_entries,
+                    )
+                    break
+                users_for_states.setdefault(key, set()).update(new_entries)
+
+        return users_for_states
 
     async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
         """
@@ -92,12 +188,36 @@ class PresenceRouter:
             A set of user IDs to return presence updates for, or ALL_USERS to return all
             known updates.
         """
-        if self.custom_presence_router is not None:
-            # Ask the custom module for interested users
-            return await self.custom_presence_router.get_interested_users(
-                user_id=user_id
-            )
 
-        # A custom presence router is not defined.
-        # Don't report any additional interested users
-        return set()
+        # Bail out early if we don't have any callbacks to run.
+        if len(self._get_interested_users_callbacks) == 0:
+            # Don't report any additional interested users
+            return set()
+
+        interested_users = set()
+        # run all the callbacks for get_interested_users and combine the results
+        for callback in self._get_interested_users_callbacks:
+            try:
+                result = await callback(user_id)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
+
+            # If one of the callbacks returns ALL_USERS then we can stop calling all
+            # of the other callbacks, since the set of interested_users is already as
+            # large as it can possibly be
+            if result == PresenceRouter.ALL_USERS:
+                return PresenceRouter.ALL_USERS
+
+            if not isinstance(result, Set):
+                logger.warning(
+                    "Wrong type returned by module API callback %s: %s, expected set",
+                    callback,
+                    result,
+                )
+                continue
+
+            # Add the new interested users to the set
+            interested_users.update(result)
+
+        return interested_users