summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-03-30 08:39:38 -0400
committerGitHub <noreply@github.com>2023-03-30 08:39:38 -0400
commitae4acda1bb903f504e946442bfc66dd0e5757dad (patch)
tree37dfc78ecbbcfcc9363f3475aa7c4a207b0733b2 /synapse
parentFix spinloop during partial state sync when a prev event is in backoff (#15351) (diff)
downloadsynapse-ae4acda1bb903f504e946442bfc66dd0e5757dad.tar.xz
Implement MSC3984 to proxy /keys/query requests to appservices. (#15321)
If enabled, for users which are exclusively owned by an application
service then the appservice will be queried for devices in addition
to any information stored in the Synapse database.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/appservice/api.py54
-rw-r--r--synapse/config/experimental.py5
-rw-r--r--synapse/federation/federation_client.py48
-rw-r--r--synapse/handlers/appservice.py61
-rw-r--r--synapse/handlers/e2e_keys.py16
-rw-r--r--synapse/http/client.py38
6 files changed, 176 insertions, 46 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 51ee0e79df..b27eedef99 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -30,7 +30,7 @@ from prometheus_client import Counter
 from typing_extensions import TypeGuard
 
 from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
-from synapse.api.errors import CodeMessageException
+from synapse.api.errors import CodeMessageException, HttpResponseException
 from synapse.appservice import (
     ApplicationService,
     TransactionOneTimeKeysCount,
@@ -38,7 +38,7 @@ from synapse.appservice import (
 )
 from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig, serialize_event
-from synapse.http.client import SimpleHttpClient
+from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
 from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.response_cache import ResponseCache
 
@@ -393,7 +393,11 @@ class ApplicationServiceApi(SimpleHttpClient):
     ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
         """Claim one time keys from an application service.
 
+        Note that any error (including a timeout) is treated as the application
+        service having no information.
+
         Args:
+            service: The application service to query.
             query: An iterable of tuples of (user ID, device ID, algorithm).
 
         Returns:
@@ -422,9 +426,9 @@ class ApplicationServiceApi(SimpleHttpClient):
                 body,
                 headers={"Authorization": [f"Bearer {service.hs_token}"]},
             )
-        except CodeMessageException as e:
+        except HttpResponseException as e:
             # The appservice doesn't support this endpoint.
-            if e.code == 404 or e.code == 405:
+            if is_unknown_endpoint(e):
                 return {}, query
             logger.warning("claim_keys to %s received %s", uri, e.code)
             return {}, query
@@ -444,6 +448,48 @@ class ApplicationServiceApi(SimpleHttpClient):
 
         return response, missing
 
+    async def query_keys(
+        self, service: "ApplicationService", query: Dict[str, List[str]]
+    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+        """Query the application service for keys.
+
+        Note that any error (including a timeout) is treated as the application
+        service having no information.
+
+        Args:
+            service: The application service to query.
+            query: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            A map of device_keys/master_keys/self_signing_keys/user_signing_keys:
+
+            device_keys is a map of user ID -> a map device ID -> device info.
+        """
+        if service.url is None:
+            return {}
+
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
+        uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3984/keys/query"
+        try:
+            response = await self.post_json_get_json(
+                uri,
+                query,
+                headers={"Authorization": [f"Bearer {service.hs_token}"]},
+            )
+        except HttpResponseException as e:
+            # The appservice doesn't support this endpoint.
+            if is_unknown_endpoint(e):
+                return {}
+            logger.warning("query_keys to %s received %s", uri, e.code)
+            return {}
+        except Exception as ex:
+            logger.warning("query_keys to %s threw exception %s", uri, ex)
+            return {}
+
+        return response
+
     def _serialize(
         self, service: "ApplicationService", events: Iterable[EventBase]
     ) -> List[JsonDict]:
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 53e6fc2b54..7687c80ea0 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -79,6 +79,11 @@ class ExperimentalConfig(Config):
             "msc3983_appservice_otk_claims", False
         )
 
+        # MSC3984: Proxying key queries to exclusive ASes.
+        self.msc3984_appservice_key_query: bool = experimental.get(
+            "msc3984_appservice_key_query", False
+        )
+
         # MSC3706 (server-side support for partial state in /send_join responses)
         # Synapse will always serve partial state responses to requests using the stable
         # query parameter `omit_members`. If this flag is set, Synapse will also serve
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 7d04560dca..4cf4957a42 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -61,6 +61,7 @@ from synapse.federation.federation_base import (
     event_from_pdu_json,
 )
 from synapse.federation.transport.client import SendJoinResponse
+from synapse.http.client import is_unknown_endpoint
 from synapse.http.types import QueryParams
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
 from synapse.types import JsonDict, UserID, get_domain_from_id
@@ -759,43 +760,6 @@ class FederationClient(FederationBase):
 
         return signed_auth
 
-    def _is_unknown_endpoint(
-        self, e: HttpResponseException, synapse_error: Optional[SynapseError] = None
-    ) -> bool:
-        """
-        Returns true if the response was due to an endpoint being unimplemented.
-
-        Args:
-            e: The error response received from the remote server.
-            synapse_error: The above error converted to a SynapseError. This is
-                automatically generated if not provided.
-
-        """
-        if synapse_error is None:
-            synapse_error = e.to_synapse_error()
-        # MSC3743 specifies that servers should return a 404 or 405 with an errcode
-        # of M_UNRECOGNIZED when they receive a request to an unknown endpoint or
-        # to an unknown method, respectively.
-        #
-        # Older versions of servers don't properly handle this. This needs to be
-        # rather specific as some endpoints truly do return 404 errors.
-        return (
-            # 404 is an unknown endpoint, 405 is a known endpoint, but unknown method.
-            (e.code == 404 or e.code == 405)
-            and (
-                # Older Dendrites returned a text or empty body.
-                # Older Conduit returned an empty body.
-                not e.response
-                or e.response == b"404 page not found"
-                # The proper response JSON with M_UNRECOGNIZED errcode.
-                or synapse_error.errcode == Codes.UNRECOGNIZED
-            )
-        ) or (
-            # Older Synapses returned a 400 error.
-            e.code == 400
-            and synapse_error.errcode == Codes.UNRECOGNIZED
-        )
-
     async def _try_destination_list(
         self,
         description: str,
@@ -887,7 +851,7 @@ class FederationClient(FederationBase):
                 elif 400 <= e.code < 500 and synapse_error.errcode in failover_errcodes:
                     failover = True
 
-                elif failover_on_unknown_endpoint and self._is_unknown_endpoint(
+                elif failover_on_unknown_endpoint and is_unknown_endpoint(
                     e, synapse_error
                 ):
                     failover = True
@@ -1223,7 +1187,7 @@ class FederationClient(FederationBase):
             # If an error is received that is due to an unrecognised endpoint,
             # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
             # and raise.
-            if not self._is_unknown_endpoint(e):
+            if not is_unknown_endpoint(e):
                 raise
 
         logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
@@ -1297,7 +1261,7 @@ class FederationClient(FederationBase):
             # fallback to the v1 endpoint if the room uses old-style event IDs.
             # Otherwise, consider it a legitimate error and raise.
             err = e.to_synapse_error()
-            if self._is_unknown_endpoint(e, err):
+            if is_unknown_endpoint(e, err):
                 if room_version.event_format != EventFormatVersions.ROOM_V1_V2:
                     raise SynapseError(
                         400,
@@ -1358,7 +1322,7 @@ class FederationClient(FederationBase):
             # If an error is received that is due to an unrecognised endpoint,
             # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
             # and raise.
-            if not self._is_unknown_endpoint(e):
+            if not is_unknown_endpoint(e):
                 raise
 
         logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
@@ -1629,7 +1593,7 @@ class FederationClient(FederationBase):
                 # If an error is received that is due to an unrecognised endpoint,
                 # fallback to the unstable endpoint. Otherwise, consider it a
                 # legitimate error and raise.
-                if not self._is_unknown_endpoint(e):
+                if not is_unknown_endpoint(e):
                     raise
 
                 logger.debug(
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 953df4d9cd..da887647d4 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -18,6 +18,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     Union,
@@ -846,6 +847,10 @@ class ApplicationServicesHandler:
     ]:
         """Claim one time keys from application services.
 
+        Users which are exclusively owned by an application service are sent a
+        key claim request to check if the application service provides keys
+        directly.
+
         Args:
             query: An iterable of tuples of (user ID, device ID, algorithm).
 
@@ -901,3 +906,59 @@ class ApplicationServicesHandler:
                 missing.extend(result[1])
 
         return claimed_keys, missing
+
+    async def query_keys(
+        self, query: Mapping[str, Optional[List[str]]]
+    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+        """Query application services for device keys.
+
+        Users which are exclusively owned by an application service are queried
+        for keys to check if the application service provides keys directly.
+
+        Args:
+            query: map from user_id to a list of devices to query
+
+        Returns:
+            A map from user_id -> device_id -> device details
+        """
+        services = self.store.get_app_services()
+
+        # Partition the users by appservice.
+        query_by_appservice: Dict[str, Dict[str, List[str]]] = {}
+        for user_id, device_ids in query.items():
+            if not self.store.get_if_app_services_interested_in_user(user_id):
+                continue
+
+            # Find the associated appservice.
+            for service in services:
+                if service.is_exclusive_user(user_id):
+                    query_by_appservice.setdefault(service.id, {})[user_id] = (
+                        device_ids or []
+                    )
+                    continue
+
+        # Query each service in parallel.
+        results = await make_deferred_yieldable(
+            defer.DeferredList(
+                [
+                    run_in_background(
+                        self.appservice_api.query_keys,
+                        # We know this must be an app service.
+                        self.store.get_app_service_by_id(service_id),  # type: ignore[arg-type]
+                        service_query,
+                    )
+                    for service_id, service_query in query_by_appservice.items()
+                ],
+                consumeErrors=True,
+            )
+        )
+
+        # Patch together the results -- they are all independent (since they
+        # require exclusive control over the users). They get returned as a single
+        # dictionary.
+        key_queries: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        for success, result in results:
+            if success:
+                key_queries.update(result)
+
+        return key_queries
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 9e7c2c45b5..0073667470 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -91,6 +91,9 @@ class E2eKeysHandler:
         self._query_appservices_for_otks = (
             hs.config.experimental.msc3983_appservice_otk_claims
         )
+        self._query_appservices_for_keys = (
+            hs.config.experimental.msc3984_appservice_key_query
+        )
 
     @trace
     @cancellable
@@ -497,6 +500,19 @@ class E2eKeysHandler:
             local_query, include_displaynames
         )
 
+        # Check if the application services have any additional results.
+        if self._query_appservices_for_keys:
+            # Query the appservices for any keys.
+            appservice_results = await self._appservice_handler.query_keys(query)
+
+            # Merge results, overriding with what the appservice returned.
+            for user_id, devices in appservice_results.get("device_keys", {}).items():
+                # Copy the appservice device info over the homeserver device info, but
+                # don't completely overwrite it.
+                results.setdefault(user_id, {}).update(devices)
+
+            # TODO Handle cross-signing keys.
+
         # Build the result structure
         for user_id, device_keys in results.items():
             for device_id, device_info in device_keys.items():
diff --git a/synapse/http/client.py b/synapse/http/client.py
index d777d59ccf..5ee55981d9 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -966,3 +966,41 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
 
     def creatorForNetloc(self, hostname: bytes, port: int) -> IOpenSSLContextFactory:
         return self
+
+
+def is_unknown_endpoint(
+    e: HttpResponseException, synapse_error: Optional[SynapseError] = None
+) -> bool:
+    """
+    Returns true if the response was due to an endpoint being unimplemented.
+
+    Args:
+        e: The error response received from the remote server.
+        synapse_error: The above error converted to a SynapseError. This is
+            automatically generated if not provided.
+
+    """
+    if synapse_error is None:
+        synapse_error = e.to_synapse_error()
+    # MSC3743 specifies that servers should return a 404 or 405 with an errcode
+    # of M_UNRECOGNIZED when they receive a request to an unknown endpoint or
+    # to an unknown method, respectively.
+    #
+    # Older versions of servers don't properly handle this. This needs to be
+    # rather specific as some endpoints truly do return 404 errors.
+    return (
+        # 404 is an unknown endpoint, 405 is a known endpoint, but unknown method.
+        (e.code == 404 or e.code == 405)
+        and (
+            # Older Dendrites returned a text body or empty body.
+            # Older Conduit returned an empty body.
+            not e.response
+            or e.response == b"404 page not found"
+            # The proper response JSON with M_UNRECOGNIZED errcode.
+            or synapse_error.errcode == Codes.UNRECOGNIZED
+        )
+    ) or (
+        # Older Synapses returned a 400 error.
+        e.code == 400
+        and synapse_error.errcode == Codes.UNRECOGNIZED
+    )