summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/appservice/api.py54
-rw-r--r--synapse/config/experimental.py5
-rw-r--r--synapse/federation/federation_client.py48
-rw-r--r--synapse/handlers/appservice.py61
-rw-r--r--synapse/handlers/e2e_keys.py16
-rw-r--r--synapse/http/client.py38
6 files changed, 176 insertions, 46 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 51ee0e79df..b27eedef99 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -30,7 +30,7 @@ from prometheus_client import Counter
 from typing_extensions import TypeGuard
 
 from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
-from synapse.api.errors import CodeMessageException
+from synapse.api.errors import CodeMessageException, HttpResponseException
 from synapse.appservice import (
     ApplicationService,
     TransactionOneTimeKeysCount,
@@ -38,7 +38,7 @@ from synapse.appservice import (
 )
 from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig, serialize_event
-from synapse.http.client import SimpleHttpClient
+from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
 from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.response_cache import ResponseCache
 
@@ -393,7 +393,11 @@ class ApplicationServiceApi(SimpleHttpClient):
     ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
         """Claim one time keys from an application service.
 
+        Note that any error (including a timeout) is treated as the application
+        service having no information.
+
         Args:
+            service: The application service to query.
             query: An iterable of tuples of (user ID, device ID, algorithm).
 
         Returns:
@@ -422,9 +426,9 @@ class ApplicationServiceApi(SimpleHttpClient):
                 body,
                 headers={"Authorization": [f"Bearer {service.hs_token}"]},
             )
-        except CodeMessageException as e:
+        except HttpResponseException as e:
             # The appservice doesn't support this endpoint.
-            if e.code == 404 or e.code == 405:
+            if is_unknown_endpoint(e):
                 return {}, query
             logger.warning("claim_keys to %s received %s", uri, e.code)
             return {}, query
@@ -444,6 +448,48 @@ class ApplicationServiceApi(SimpleHttpClient):
 
         return response, missing
 
+    async def query_keys(
+        self, service: "ApplicationService", query: Dict[str, List[str]]
+    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+        """Query the application service for keys.
+
+        Note that any error (including a timeout) is treated as the application
+        service having no information.
+
+        Args:
+            service: The application service to query.
+            query: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            A map of device_keys/master_keys/self_signing_keys/user_signing_keys:
+
+            device_keys is a map of user ID -> a map device ID -> device info.
+        """
+        if service.url is None:
+            return {}
+
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
+        uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3984/keys/query"
+        try:
+            response = await self.post_json_get_json(
+                uri,
+                query,
+                headers={"Authorization": [f"Bearer {service.hs_token}"]},
+            )
+        except HttpResponseException as e:
+            # The appservice doesn't support this endpoint.
+            if is_unknown_endpoint(e):
+                return {}
+            logger.warning("query_keys to %s received %s", uri, e.code)
+            return {}
+        except Exception as ex:
+            logger.warning("query_keys to %s threw exception %s", uri, ex)
+            return {}
+
+        return response
+
     def _serialize(
         self, service: "ApplicationService", events: Iterable[EventBase]
     ) -> List[JsonDict]:
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 53e6fc2b54..7687c80ea0 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -79,6 +79,11 @@ class ExperimentalConfig(Config):
             "msc3983_appservice_otk_claims", False
         )
 
+        # MSC3984: Proxying key queries to exclusive ASes.
+        self.msc3984_appservice_key_query: bool = experimental.get(
+            "msc3984_appservice_key_query", False
+        )
+
         # MSC3706 (server-side support for partial state in /send_join responses)
         # Synapse will always serve partial state responses to requests using the stable
         # query parameter `omit_members`. If this flag is set, Synapse will also serve
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 7d04560dca..4cf4957a42 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -61,6 +61,7 @@ from synapse.federation.federation_base import (
     event_from_pdu_json,
 )
 from synapse.federation.transport.client import SendJoinResponse
+from synapse.http.client import is_unknown_endpoint
 from synapse.http.types import QueryParams
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
 from synapse.types import JsonDict, UserID, get_domain_from_id
@@ -759,43 +760,6 @@ class FederationClient(FederationBase):
 
         return signed_auth
 
-    def _is_unknown_endpoint(
-        self, e: HttpResponseException, synapse_error: Optional[SynapseError] = None
-    ) -> bool:
-        """
-        Returns true if the response was due to an endpoint being unimplemented.
-
-        Args:
-            e: The error response received from the remote server.
-            synapse_error: The above error converted to a SynapseError. This is
-                automatically generated if not provided.
-
-        """
-        if synapse_error is None:
-            synapse_error = e.to_synapse_error()
-        # MSC3743 specifies that servers should return a 404 or 405 with an errcode
-        # of M_UNRECOGNIZED when they receive a request to an unknown endpoint or
-        # to an unknown method, respectively.
-        #
-        # Older versions of servers don't properly handle this. This needs to be
-        # rather specific as some endpoints truly do return 404 errors.
-        return (
-            # 404 is an unknown endpoint, 405 is a known endpoint, but unknown method.
-            (e.code == 404 or e.code == 405)
-            and (
-                # Older Dendrites returned a text or empty body.
-                # Older Conduit returned an empty body.
-                not e.response
-                or e.response == b"404 page not found"
-                # The proper response JSON with M_UNRECOGNIZED errcode.
-                or synapse_error.errcode == Codes.UNRECOGNIZED
-            )
-        ) or (
-            # Older Synapses returned a 400 error.
-            e.code == 400
-            and synapse_error.errcode == Codes.UNRECOGNIZED
-        )
-
     async def _try_destination_list(
         self,
         description: str,
@@ -887,7 +851,7 @@ class FederationClient(FederationBase):
                 elif 400 <= e.code < 500 and synapse_error.errcode in failover_errcodes:
                     failover = True
 
-                elif failover_on_unknown_endpoint and self._is_unknown_endpoint(
+                elif failover_on_unknown_endpoint and is_unknown_endpoint(
                     e, synapse_error
                 ):
                     failover = True
@@ -1223,7 +1187,7 @@ class FederationClient(FederationBase):
             # If an error is received that is due to an unrecognised endpoint,
             # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
             # and raise.
-            if not self._is_unknown_endpoint(e):
+            if not is_unknown_endpoint(e):
                 raise
 
         logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
@@ -1297,7 +1261,7 @@ class FederationClient(FederationBase):
             # fallback to the v1 endpoint if the room uses old-style event IDs.
             # Otherwise, consider it a legitimate error and raise.
             err = e.to_synapse_error()
-            if self._is_unknown_endpoint(e, err):
+            if is_unknown_endpoint(e, err):
                 if room_version.event_format != EventFormatVersions.ROOM_V1_V2:
                     raise SynapseError(
                         400,
@@ -1358,7 +1322,7 @@ class FederationClient(FederationBase):
             # If an error is received that is due to an unrecognised endpoint,
             # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
             # and raise.
-            if not self._is_unknown_endpoint(e):
+            if not is_unknown_endpoint(e):
                 raise
 
         logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
@@ -1629,7 +1593,7 @@ class FederationClient(FederationBase):
                 # If an error is received that is due to an unrecognised endpoint,
                 # fallback to the unstable endpoint. Otherwise, consider it a
                 # legitimate error and raise.
-                if not self._is_unknown_endpoint(e):
+                if not is_unknown_endpoint(e):
                     raise
 
                 logger.debug(
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 953df4d9cd..da887647d4 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -18,6 +18,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     Union,
@@ -846,6 +847,10 @@ class ApplicationServicesHandler:
     ]:
         """Claim one time keys from application services.
 
+        Users which are exclusively owned by an application service are sent a
+        key claim request to check if the application service provides keys
+        directly.
+
         Args:
             query: An iterable of tuples of (user ID, device ID, algorithm).
 
@@ -901,3 +906,59 @@ class ApplicationServicesHandler:
                 missing.extend(result[1])
 
         return claimed_keys, missing
+
+    async def query_keys(
+        self, query: Mapping[str, Optional[List[str]]]
+    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+        """Query application services for device keys.
+
+        Users which are exclusively owned by an application service are queried
+        for keys to check if the application service provides keys directly.
+
+        Args:
+            query: map from user_id to a list of devices to query
+
+        Returns:
+            A map from user_id -> device_id -> device details
+        """
+        services = self.store.get_app_services()
+
+        # Partition the users by appservice.
+        query_by_appservice: Dict[str, Dict[str, List[str]]] = {}
+        for user_id, device_ids in query.items():
+            if not self.store.get_if_app_services_interested_in_user(user_id):
+                continue
+
+            # Find the associated appservice.
+            for service in services:
+                if service.is_exclusive_user(user_id):
+                    query_by_appservice.setdefault(service.id, {})[user_id] = (
+                        device_ids or []
+                    )
+                    continue
+
+        # Query each service in parallel.
+        results = await make_deferred_yieldable(
+            defer.DeferredList(
+                [
+                    run_in_background(
+                        self.appservice_api.query_keys,
+                        # We know this must be an app service.
+                        self.store.get_app_service_by_id(service_id),  # type: ignore[arg-type]
+                        service_query,
+                    )
+                    for service_id, service_query in query_by_appservice.items()
+                ],
+                consumeErrors=True,
+            )
+        )
+
+        # Patch together the results -- they are all independent (since they
+        # require exclusive control over the users). They get returned as a single
+        # dictionary.
+        key_queries: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        for success, result in results:
+            if success:
+                key_queries.update(result)
+
+        return key_queries
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 9e7c2c45b5..0073667470 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -91,6 +91,9 @@ class E2eKeysHandler:
         self._query_appservices_for_otks = (
             hs.config.experimental.msc3983_appservice_otk_claims
         )
+        self._query_appservices_for_keys = (
+            hs.config.experimental.msc3984_appservice_key_query
+        )
 
     @trace
     @cancellable
@@ -497,6 +500,19 @@ class E2eKeysHandler:
             local_query, include_displaynames
         )
 
+        # Check if the application services have any additional results.
+        if self._query_appservices_for_keys:
+            # Query the appservices for any keys.
+            appservice_results = await self._appservice_handler.query_keys(query)
+
+            # Merge results, overriding with what the appservice returned.
+            for user_id, devices in appservice_results.get("device_keys", {}).items():
+                # Copy the appservice device info over the homeserver device info, but
+                # don't completely overwrite it.
+                results.setdefault(user_id, {}).update(devices)
+
+            # TODO Handle cross-signing keys.
+
         # Build the result structure
         for user_id, device_keys in results.items():
             for device_id, device_info in device_keys.items():
diff --git a/synapse/http/client.py b/synapse/http/client.py
index d777d59ccf..5ee55981d9 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -966,3 +966,41 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
 
     def creatorForNetloc(self, hostname: bytes, port: int) -> IOpenSSLContextFactory:
         return self
+
+
+def is_unknown_endpoint(
+    e: HttpResponseException, synapse_error: Optional[SynapseError] = None
+) -> bool:
+    """
+    Returns true if the response was due to an endpoint being unimplemented.
+
+    Args:
+        e: The error response received from the remote server.
+        synapse_error: The above error converted to a SynapseError. This is
+            automatically generated if not provided.
+
+    """
+    if synapse_error is None:
+        synapse_error = e.to_synapse_error()
+    # MSC3743 specifies that servers should return a 404 or 405 with an errcode
+    # of M_UNRECOGNIZED when they receive a request to an unknown endpoint or
+    # to an unknown method, respectively.
+    #
+    # Older versions of servers don't properly handle this. This needs to be
+    # rather specific as some endpoints truly do return 404 errors.
+    return (
+        # 404 is an unknown endpoint, 405 is a known endpoint, but unknown method.
+        (e.code == 404 or e.code == 405)
+        and (
+            # Older Dendrites returned a text body or empty body.
+            # Older Conduit returned an empty body.
+            not e.response
+            or e.response == b"404 page not found"
+            # The proper response JSON with M_UNRECOGNIZED errcode.
+            or synapse_error.errcode == Codes.UNRECOGNIZED
+        )
+    ) or (
+        # Older Synapses returned a 400 error.
+        e.code == 400
+        and synapse_error.errcode == Codes.UNRECOGNIZED
+    )