summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-02-22 16:10:10 +0100
committerGitHub <noreply@github.com>2022-02-22 15:10:10 +0000
commit250104d357c17a1c87fa46af35bbf3612f4ef171 (patch)
tree240530842209062c8366f22170dc0651d6f0fae5 /synapse/federation
parentPrune setup.cfg some more (#12059) (diff)
downloadsynapse-250104d357c17a1c87fa46af35bbf3612f4ef171.tar.xz
Implement account status endpoints (MSC3720) (#12001)
See matrix-org/matrix-doc#3720

Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py60
-rw-r--r--synapse/federation/transport/client.py19
-rw-r--r--synapse/federation/transport/server/__init__.py8
-rw-r--r--synapse/federation/transport/server/federation.py35
4 files changed, 120 insertions, 2 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c2997997da..2121e92e3a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -56,7 +56,7 @@ from synapse.api.room_versions import (
 from synapse.events import EventBase, builder
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.transport.client import SendJoinResponse
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
@@ -1610,6 +1610,64 @@ class FederationClient(FederationBase):
         except ValueError as e:
             raise InvalidResponseError(str(e))
 
+    async def get_account_status(
+        self, destination: str, user_ids: List[str]
+    ) -> Tuple[JsonDict, List[str]]:
+        """Retrieves account statuses for a given list of users on a given remote
+        homeserver.
+
+        If the request fails for any reason, all user IDs for this destination are marked
+        as failed.
+
+        Args:
+            destination: the destination to contact
+            user_ids: the user ID(s) for which to request account status(es)
+
+        Returns:
+            The account statuses, as well as the list of user IDs for which it was not
+            possible to retrieve a status.
+        """
+        try:
+            res = await self.transport_layer.get_account_status(destination, user_ids)
+        except Exception:
+            # If the query failed for any reason, mark all the users as failed.
+            return {}, user_ids
+
+        statuses = res.get("account_statuses", {})
+        failures = res.get("failures", [])
+
+        if not isinstance(statuses, dict) or not isinstance(failures, list):
+            # Make sure we're not feeding back malformed data back to the caller.
+            logger.warning(
+                "Destination %s responded with malformed data to account_status query",
+                destination,
+            )
+            return {}, user_ids
+
+        for user_id in user_ids:
+            # Any account whose status is missing is a user we failed to receive the
+            # status of.
+            if user_id not in statuses and user_id not in failures:
+                failures.append(user_id)
+
+        # Filter out any user ID that doesn't belong to the remote server that sent its
+        # status (or failure).
+        def filter_user_id(user_id: str) -> bool:
+            try:
+                return UserID.from_string(user_id).domain == destination
+            except SynapseError:
+                # If the user ID doesn't parse, ignore it.
+                return False
+
+        filtered_statuses = dict(
+            # item is a (key, value) tuple, so item[0] is the user ID.
+            filter(lambda item: filter_user_id(item[0]), statuses.items())
+        )
+
+        filtered_failures = list(filter(filter_user_id, failures))
+
+        return filtered_statuses, filtered_failures
+
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
 class TimestampToEventResponse:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 7e510e224a..69998de520 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -258,8 +258,9 @@ class TransportLayerClient:
         args: dict,
         retry_on_dns_fail: bool,
         ignore_backoff: bool = False,
+        prefix: str = FEDERATION_V1_PREFIX,
     ) -> JsonDict:
-        path = _create_v1_path("/query/%s", query_type)
+        path = _create_path(prefix, "/query/%s", query_type)
 
         return await self.client.get_json(
             destination=destination,
@@ -1247,6 +1248,22 @@ class TransportLayerClient:
             args={"suggested_only": "true" if suggested_only else "false"},
         )
 
+    async def get_account_status(
+        self, destination: str, user_ids: List[str]
+    ) -> JsonDict:
+        """
+        Args:
+            destination: The remote server.
+            user_ids: The user ID(s) for which to request account status(es).
+        """
+        path = _create_path(
+            FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc3720/account_status"
+        )
+
+        return await self.client.post_json(
+            destination=destination, path=path, data={"user_ids": user_ids}
+        )
+
 
 def _create_path(federation_prefix: str, path: str, *args: str) -> str:
     """
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index db4fe2c798..67a6347907 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -24,6 +24,7 @@ from synapse.federation.transport.server._base import (
 )
 from synapse.federation.transport.server.federation import (
     FEDERATION_SERVLET_CLASSES,
+    FederationAccountStatusServlet,
     FederationTimestampLookupServlet,
 )
 from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES
@@ -336,6 +337,13 @@ def register_servlets(
             ):
                 continue
 
+            # Only allow the `/account_status` servlet if msc3720 is enabled
+            if (
+                servletclass == FederationAccountStatusServlet
+                and not hs.config.experimental.msc3720_enabled
+            ):
+                continue
+
             servletclass(
                 hs=hs,
                 authenticator=authenticator,
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index e85a8eda5b..4d75e58bfc 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -766,6 +766,40 @@ class RoomComplexityServlet(BaseFederationServlet):
         return 200, complexity
 
 
+class FederationAccountStatusServlet(BaseFederationServerServlet):
+    PATH = "/query/account_status"
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3720"
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self._account_handler = hs.get_account_handler()
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Mapping[bytes, Sequence[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        if "user_ids" not in content:
+            raise SynapseError(
+                400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM
+            )
+
+        statuses, failures = await self._account_handler.get_account_statuses(
+            content["user_ids"],
+            allow_remote=False,
+        )
+
+        return 200, {"account_statuses": statuses, "failures": failures}
+
+
 FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationSendServlet,
     FederationEventServlet,
@@ -797,4 +831,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationRoomHierarchyUnstableServlet,
     FederationV1SendKnockServlet,
     FederationMakeKnockServlet,
+    FederationAccountStatusServlet,
 )