summary refs log tree commit diff
diff options
context:
space:
mode:
author: Patrick Cloke <clokep@users.noreply.github.com> 2023-03-28 14:26:27 -0400
committer: GitHub <noreply@github.com> 2023-03-28 18:26:27 +0000
commit: 5282ba1e2bbff2635dc09aec45fd42a56c1a4545 (patch)
tree: 94377879ae342e639bb05c2257765c7f94bc048e
parent: Speed up generate sample config CI lint (#15340) (diff)
download: synapse-5282ba1e2bbff2635dc09aec45fd42a56c1a4545.tar.xz
Implement MSC3983 to proxy /keys/claim queries to appservices. (#15314)
Experimental support for MSC3983 is behind a configuration flag.
If enabled, then for users who are exclusively owned by an application
service, the appservice will be queried for one-time keys *if* there are
none uploaded to Synapse.
-rw-r--r-- changelog.d/15314.feature 1
-rw-r--r-- synapse/appservice/api.py 56
-rw-r--r-- synapse/config/experimental.py 5
-rw-r--r-- synapse/federation/federation_server.py 20
-rw-r--r-- synapse/handlers/appservice.py 74
-rw-r--r-- synapse/handlers/e2e_keys.py 57
-rw-r--r-- synapse/storage/databases/main/end_to_end_keys.py 36
-rw-r--r-- tests/appservice/test_api.py 59
-rw-r--r-- tests/handlers/test_e2e_keys.py 76
9 files changed, 355 insertions, 29 deletions
diff --git a/changelog.d/15314.feature b/changelog.d/15314.feature
new file mode 100644
index 0000000000..68b289b0cc
--- /dev/null
+++ b/changelog.d/15314.feature
@@ -0,0 +1 @@
+Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)).
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 4812fb4496..51ee0e79df 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -388,6 +388,62 @@ class ApplicationServiceApi(SimpleHttpClient):
         failed_transactions_counter.labels(service.id).inc()
         return False
 
+    async def claim_client_keys(
+        self, service: "ApplicationService", query: List[Tuple[str, str, str]]
+    ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
+        """Claim one time keys from an application service.
+
+        Args:
+            query: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            A tuple of:
+                A map of user ID -> a map device ID -> a map of key ID -> JSON dict.
+
+                A copy of the input which has not been fulfilled because the
+                appservice doesn't support this endpoint or has not returned
+                data for that tuple.
+        """
+        if service.url is None:
+            return {}, query
+
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
+        # Create the expected payload shape.
+        body: Dict[str, Dict[str, List[str]]] = {}
+        for user_id, device, algorithm in query:
+            body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
+
+        uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
+        try:
+            response = await self.post_json_get_json(
+                uri,
+                body,
+                headers={"Authorization": [f"Bearer {service.hs_token}"]},
+            )
+        except CodeMessageException as e:
+            # The appservice doesn't support this endpoint.
+            if e.code == 404 or e.code == 405:
+                return {}, query
+            logger.warning("claim_keys to %s received %s", uri, e.code)
+            return {}, query
+        except Exception as ex:
+            logger.warning("claim_keys to %s threw exception %s", uri, ex)
+            return {}, query
+
+        # Check if the appservice fulfilled all of the queried user/device/algorithms
+        # or if some are still missing.
+        #
+        # TODO This places a lot of faith in the response shape being correct.
+        missing = [
+            (user_id, device, algorithm)
+            for user_id, device, algorithm in query
+            if algorithm not in response.get(user_id, {}).get(device, [])
+        ]
+
+        return response, missing
+
     def _serialize(
         self, service: "ApplicationService", events: Iterable[EventBase]
     ) -> List[JsonDict]:
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 99dcd27c74..53e6fc2b54 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -74,6 +74,11 @@ class ExperimentalConfig(Config):
             "msc3202_transaction_extensions", False
         )
 
+        # MSC3983: Proxying OTK claim requests to exclusive ASes.
+        self.msc3983_appservice_otk_claims: bool = experimental.get(
+            "msc3983_appservice_otk_claims", False
+        )
+
         # MSC3706 (server-side support for partial state in /send_join responses)
         # Synapse will always serve partial state responses to requests using the stable
         # query parameter `omit_members`. If this flag is set, Synapse will also serve
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 6d99845de5..64e99292ec 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -86,7 +86,7 @@ from synapse.storage.databases.main.lock import Lock
 from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
 from synapse.storage.roommember import MemberSummary
 from synapse.types import JsonDict, StateMap, get_domain_from_id
-from synapse.util import json_decoder, unwrapFirstError
+from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import parse_server_name
@@ -135,6 +135,7 @@ class FederationServer(FederationBase):
         self.state = hs.get_state_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
         self._room_member_handler = hs.get_room_member_handler()
+        self._e2e_keys_handler = hs.get_e2e_keys_handler()
 
         self._state_storage_controller = hs.get_storage_controllers().state
 
@@ -1012,15 +1013,14 @@ class FederationServer(FederationBase):
                 query.append((user_id, device_id, algorithm))
 
         log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
-        results = await self.store.claim_e2e_one_time_keys(query)
-
-        json_result: Dict[str, Dict[str, dict]] = {}
-        for user_id, device_keys in results.items():
-            for device_id, keys in device_keys.items():
-                for key_id, json_str in keys.items():
-                    json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json_decoder.decode(json_str)
-                    }
+        results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
+
+        json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        for result in results:
+            for user_id, device_keys in result.items():
+                for device_id, keys in device_keys.items():
+                    for key_id, key in keys.items():
+                        json_result.setdefault(user_id, {})[device_id] = {key_id: key}
 
         logger.info(
             "Claimed one-time-keys: %s",
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index ec3ab968e9..953df4d9cd 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,7 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from prometheus_client import Counter
 
@@ -829,3 +838,66 @@ class ApplicationServicesHandler:
         if unknown_user:
             return await self.query_user_exists(user_id)
         return True
+
+    async def claim_e2e_one_time_keys(
+        self, query: Iterable[Tuple[str, str, str]]
+    ) -> Tuple[
+        Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
+    ]:
+        """Claim one time keys from application services.
+
+        Args:
+            query: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            A tuple of:
+                An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON dict.
+
+                A copy of the input which has not been fulfilled (either because
+                they are not appservice users or the appservice does not support
+                providing OTKs).
+        """
+        services = self.store.get_app_services()
+
+        # Partition the users by appservice.
+        query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
+        missing = []
+        for user_id, device, algorithm in query:
+            if not self.store.get_if_app_services_interested_in_user(user_id):
+                missing.append((user_id, device, algorithm))
+                continue
+
+            # Find the associated appservice.
+            for service in services:
+                if service.is_exclusive_user(user_id):
+                    query_by_appservice.setdefault(service.id, []).append(
+                        (user_id, device, algorithm)
+                    )
+                    continue
+
+        # Query each service in parallel.
+        results = await make_deferred_yieldable(
+            defer.DeferredList(
+                [
+                    run_in_background(
+                        self.appservice_api.claim_client_keys,
+                        # We know this must be an app service.
+                        self.store.get_app_service_by_id(service_id),  # type: ignore[arg-type]
+                        service_query,
+                    )
+                    for service_id, service_query in query_by_appservice.items()
+                ],
+                consumeErrors=True,
+            )
+        )
+
+        # Patch together the results -- they are all independent (since they
+        # require exclusive control over the users). They get returned as a list
+        # and the caller combines them.
+        claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
+        for success, result in results:
+            if success:
+                claimed_keys.append(result[0])
+                missing.extend(result[1])
+
+        return claimed_keys, missing
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 4e9c8d8db0..9e7c2c45b5 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
@@ -53,6 +52,7 @@ class E2eKeysHandler:
         self.store = hs.get_datastores().main
         self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
+        self._appservice_handler = hs.get_application_service_handler()
         self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
 
@@ -88,6 +88,10 @@ class E2eKeysHandler:
             max_count=10,
         )
 
+        self._query_appservices_for_otks = (
+            hs.config.experimental.msc3983_appservice_otk_claims
+        )
+
     @trace
     @cancellable
     async def query_devices(
@@ -542,6 +546,42 @@ class E2eKeysHandler:
 
         return ret
 
+    async def claim_local_one_time_keys(
+        self, local_query: List[Tuple[str, str, str]]
+    ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
+        """Claim one time keys for local users.
+
+        1. Attempt to claim OTKs from the database.
+        2. Ask application services if they provide OTKs.
+        3. Attempt to fetch fallback keys from the database.
+
+        Args:
+            local_query: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON dict.
+        """
+
+        otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
+
+        # If the application services have not provided any keys via the C-S
+        # API, query them directly for one-time keys.
+        if self._query_appservices_for_otks:
+            (
+                appservice_results,
+                not_found,
+            ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
+        else:
+            appservice_results = []
+
+        # For each user that does not have a one-time key available, see if
+        # there is a fallback key.
+        fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
+
+        # Return the results in order, each item from the input query should
+        # only appear once in the combined list.
+        return (otk_results, *appservice_results, fallback_results)
+
     @trace
     async def claim_one_time_keys(
         self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
@@ -561,17 +601,18 @@ class E2eKeysHandler:
         set_tag("local_key_query", str(local_query))
         set_tag("remote_key_query", str(remote_queries))
 
-        results = await self.store.claim_e2e_one_time_keys(local_query)
+        results = await self.claim_local_one_time_keys(local_query)
 
         # A map of user ID -> device ID -> key ID -> key.
         json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        for result in results:
+            for user_id, device_keys in result.items():
+                for device_id, keys in device_keys.items():
+                    for key_id, key in keys.items():
+                        json_result.setdefault(user_id, {})[device_id] = {key_id: key}
+
+        # Remote failures.
         failures: Dict[str, JsonDict] = {}
-        for user_id, device_keys in results.items():
-            for device_id, keys in device_keys.items():
-                for key_id, json_str in keys.items():
-                    json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json_decoder.decode(json_str)
-                    }
 
         @trace
         async def claim_client_keys(destination: str) -> None:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index a3b6c8ae8e..dc7768c50c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -51,7 +51,7 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
@@ -1028,14 +1028,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
     async def claim_e2e_one_time_keys(
         self, query_list: Iterable[Tuple[str, str, str]]
-    ) -> Dict[str, Dict[str, Dict[str, str]]]:
+    ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
         """Take a list of one time keys out of the database.
 
         Args:
             query_list: An iterable of tuples of (user ID, device ID, algorithm).
 
         Returns:
-            A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+            A tuple of:
+                A map of user ID -> a map device ID -> a map of key ID -> JSON.
+
+                A copy of the input which has not been fulfilled.
         """
 
         @trace
@@ -1115,7 +1118,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             key_id, key_json = otk_row
             return f"{algorithm}:{key_id}", key_json
 
-        results: Dict[str, Dict[str, Dict[str, str]]] = {}
+        results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        missing: List[Tuple[str, str, str]] = []
         for user_id, device_id, algorithm in query_list:
             if self.database_engine.supports_returning:
                 # If we support RETURNING clause we can use a single query that
@@ -1138,11 +1142,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 device_results = results.setdefault(user_id, {}).setdefault(
                     device_id, {}
                 )
-                device_results[claim_row[0]] = claim_row[1]
-                continue
+                device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+            else:
+                missing.append((user_id, device_id, algorithm))
+
+        return results, missing
+
+    async def claim_e2e_fallback_keys(
+        self, query_list: Iterable[Tuple[str, str, str]]
+    ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+        """Take a list of fallback keys out of the database.
 
-            # No one-time key available, so see if there's a fallback
-            # key
+        Args:
+            query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            A map of user ID -> a map device ID -> a map of key ID -> JSON.
+        """
+        results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        for user_id, device_id, algorithm in query_list:
             row = await self.db_pool.simple_select_one(
                 table="e2e_fallback_keys_json",
                 keyvalues={
@@ -1179,7 +1197,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 )
 
             device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
-            device_results[f"{algorithm}:{key_id}"] = key_json
+            device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
 
         return results
 
diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py
index 9d183b733e..0dd02b7d58 100644
--- a/tests/appservice/test_api.py
+++ b/tests/appservice/test_api.py
@@ -105,3 +105,62 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(self.request_url, URL_LOCATION)
         self.assertEqual(result, SUCCESS_RESULT_LOCATION)
+
+    def test_claim_keys(self) -> None:
+        """
+        Tests that the /keys/claim response is properly parsed for missing
+        keys.
+        """
+
+        RESPONSE: JsonDict = {
+            "@alice:example.org": {
+                "DEVICE_1": {
+                    "signed_curve25519:AAAAHg": {
+                        # We don't really care about the content of the keys,
+                        # they get passed back transparently.
+                    },
+                    "signed_curve25519:BBBBHg": {},
+                },
+                "DEVICE_2": {"signed_curve25519:CCCCHg": {}},
+            },
+        }
+
+        async def post_json_get_json(
+            uri: str,
+            post_json: Any,
+            headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
+        ) -> JsonDict:
+            # Ensure the access token is passed as an Authorization header.
+            if not headers.get("Authorization"):
+                raise RuntimeError("Access token not provided")
+
+            self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
+            return RESPONSE
+
+        # We assign to a method, which mypy doesn't like.
+        self.api.post_json_get_json = Mock(side_effect=post_json_get_json)  # type: ignore[assignment]
+
+        MISSING_KEYS = [
+            # Known user, known device, missing algorithm.
+            ("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"),
+            # Known user, missing device.
+            ("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"),
+            # Unknown user.
+            ("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"),
+        ]
+
+        claimed_keys, missing = self.get_success(
+            self.api.claim_client_keys(
+                self.service,
+                [
+                    # Found devices
+                    ("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"),
+                    ("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"),
+                    ("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"),
+                ]
+                + MISSING_KEYS,
+            )
+        )
+
+        self.assertEqual(claimed_keys, RESPONSE)
+        self.assertEqual(missing, MISSING_KEYS)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 6b4cba65d0..4ff04fc66b 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -23,18 +23,24 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import Codes, SynapseError
+from synapse.appservice import ApplicationService
 from synapse.handlers.device import DeviceHandler
 from synapse.server import HomeServer
+from synapse.storage.databases.main.appservice import _make_exclusive_regex
 from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
+from tests.unittest import override_config
 
 
 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        return self.setup_test_homeserver(federation_client=mock.Mock())
+        self.appservice_api = mock.Mock()
+        return self.setup_test_homeserver(
+            federation_client=mock.Mock(), application_service_api=self.appservice_api
+        )
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = hs.get_e2e_keys_handler()
@@ -941,3 +947,71 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
             # The two requests to the local homeserver should be identical.
             self.assertEqual(response_1, response_2)
+
+    @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
+    def test_query_appservice(self) -> None:
+        local_user = "@boris:" + self.hs.hostname
+        device_id_1 = "xyz"
+        fallback_key = {"alg1:k1": "fallback_key1"}
+        device_id_2 = "abc"
+        otk = {"alg1:k2": "key2"}
+
+        # Inject an appservice interested in this user.
+        appservice = ApplicationService(
+            token="i_am_an_app_service",
+            id="1234",
+            namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]},
+            # Note: this user does not have to match the regex above
+            sender="@as_main:test",
+        )
+        self.hs.get_datastores().main.services_cache = [appservice]
+        self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+            [appservice]
+        )
+
+        # Setup a response, but only for device 2.
+        self.appservice_api.claim_client_keys.return_value = make_awaitable(
+            ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")])
+        )
+
+        # we shouldn't have any unused fallback keys yet
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(res, [])
+
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id_1,
+                {"fallback_keys": fallback_key},
+            )
+        )
+
+        # we should now have an unused alg1 key
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # claiming an OTK when no OTKs are available should ask the appservice, then
+        # query the fallback keys.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {
+                    "one_time_keys": {
+                        local_user: {device_id_1: "alg1", device_id_2: "alg1"}
+                    }
+                },
+                timeout=None,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {
+                    local_user: {device_id_1: fallback_key, device_id_2: otk}
+                },
+            },
+        )