summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-04-25 13:30:41 -0400
committerGitHub <noreply@github.com>2023-04-25 13:30:41 -0400
commit8e9739449dd6d3c133adf9e995d27d06518a0bcf (patch)
treefc3a5f11b23315b18ea87b0217b1dbd954c058ed
parentMerge branch 'master' into develop (diff)
downloadsynapse-8e9739449dd6d3c133adf9e995d27d06518a0bcf.tar.xz
Add unstable /keys/claim endpoint which always returns fallback keys. (#15462)
It can be useful to always return the fallback key when attempting to
claim keys. This adds an unstable endpoint for `/keys/claim` which
always returns fallback keys in addition to one-time-keys.

The fallback key(s) are not marked as "used" unless there are no
corresponding OTKs.

This is currently defined in MSC3983 (although likely to be split out
to a separate MSC). The endpoint shape may change or be requested
differently (i.e. a keyword parameter on the current endpoint), but the
core logic should be reasonable.
-rw-r--r--changelog.d/15462.misc1
-rw-r--r--synapse/federation/federation_server.py6
-rw-r--r--synapse/federation/transport/server/__init__.py6
-rw-r--r--synapse/federation/transport/server/federation.py23
-rw-r--r--synapse/handlers/appservice.py13
-rw-r--r--synapse/handlers/e2e_keys.py70
-rw-r--r--synapse/rest/client/keys.py31
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py9
-rw-r--r--tests/handlers/test_e2e_keys.py241
9 files changed, 371 insertions, 29 deletions
diff --git a/changelog.d/15462.misc b/changelog.d/15462.misc
new file mode 100644
index 0000000000..36e4bffbc8
--- /dev/null
+++ b/changelog.d/15462.misc
@@ -0,0 +1 @@
+Update support for [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d7740eb3b4..c618f3d7a6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1005,7 +1005,7 @@ class FederationServer(FederationBase):
 
     @trace
     async def on_claim_client_keys(
-        self, origin: str, content: JsonDict
+        self, origin: str, content: JsonDict, always_include_fallback_keys: bool
     ) -> Dict[str, Any]:
         query = []
         for user_id, device_keys in content.get("one_time_keys", {}).items():
@@ -1013,7 +1013,9 @@ class FederationServer(FederationBase):
                 query.append((user_id, device_id, algorithm))
 
         log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
-        results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
+        results = await self._e2e_keys_handler.claim_local_one_time_keys(
+            query, always_include_fallback_keys=always_include_fallback_keys
+        )
 
         json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         for result in results:
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index 753372fc54..55d2cd0a9a 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -25,6 +25,7 @@ from synapse.federation.transport.server._base import (
 from synapse.federation.transport.server.federation import (
     FEDERATION_SERVLET_CLASSES,
     FederationAccountStatusServlet,
+    FederationUnstableClientKeysClaimServlet,
 )
 from synapse.http.server import HttpServer, JsonResource
 from synapse.http.servlet import (
@@ -298,6 +299,11 @@ def register_servlets(
                 and not hs.config.experimental.msc3720_enabled
             ):
                 continue
+            if (
+                servletclass == FederationUnstableClientKeysClaimServlet
+                and not hs.config.experimental.msc3983_appservice_otk_claims
+            ):
+                continue
 
             servletclass(
                 hs=hs,
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index ec5b5eeafa..e2340d70d5 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -577,7 +577,28 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
     async def on_POST(
         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
     ) -> Tuple[int, JsonDict]:
-        response = await self.handler.on_claim_client_keys(origin, content)
+        response = await self.handler.on_claim_client_keys(
+            origin, content, always_include_fallback_keys=False
+        )
+        return 200, response
+
+
+class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
+    """
+    Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
+    always includes fallback keys in the response.
+    """
+
+    PREFIX = FEDERATION_UNSTABLE_PREFIX
+    PATH = "/user/keys/claim"
+    CATEGORY = "Federation requests"
+
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        response = await self.handler.on_claim_client_keys(
+            origin, content, always_include_fallback_keys=True
+        )
         return 200, response
 
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index da887647d4..4ca2bc0420 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -842,9 +842,7 @@ class ApplicationServicesHandler:
 
     async def claim_e2e_one_time_keys(
         self, query: Iterable[Tuple[str, str, str]]
-    ) -> Tuple[
-        Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
-    ]:
+    ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
         """Claim one time keys from application services.
 
         Users which are exclusively owned by an application service are sent a
@@ -856,7 +854,7 @@ class ApplicationServicesHandler:
 
         Returns:
             A tuple of:
-                An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+                A map of user ID -> a map device ID -> a map of key ID -> JSON.
 
                 A copy of the input which has not been fulfilled (either because
                 they are not appservice users or the appservice does not support
@@ -897,12 +895,11 @@ class ApplicationServicesHandler:
         )
 
         # Patch together the results -- they are all independent (since they
-        # require exclusive control over the users). They get returned as a list
-        # and the caller combines them.
-        claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
+        # require exclusive control over the users, which is the outermost key).
+        claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         for success, result in results:
             if success:
-                claimed_keys.append(result[0])
+                claimed_keys.update(result[0])
                 missing.extend(result[1])
 
         return claimed_keys, missing
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 0073667470..d1ab95126c 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -563,7 +563,9 @@ class E2eKeysHandler:
         return ret
 
     async def claim_local_one_time_keys(
-        self, local_query: List[Tuple[str, str, str]]
+        self,
+        local_query: List[Tuple[str, str, str]],
+        always_include_fallback_keys: bool,
     ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
         """Claim one time keys for local users.
 
@@ -573,6 +575,7 @@ class E2eKeysHandler:
 
         Args:
             local_query: An iterable of tuples of (user ID, device ID, algorithm).
+            always_include_fallback_keys: True to always include fallback keys.
 
         Returns:
             An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
@@ -583,24 +586,73 @@ class E2eKeysHandler:
         # If the application services have not provided any keys via the C-S
         # API, query it directly for one-time keys.
         if self._query_appservices_for_otks:
+            # TODO Should this query for fallback keys of uploaded OTKs if
+            #      always_include_fallback_keys is True? The MSC is ambiguous.
             (
                 appservice_results,
                 not_found,
             ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
         else:
-            appservice_results = []
+            appservice_results = {}
+
+        # Calculate which user ID / device ID / algorithm tuples to get fallback
+        # keys for. This can be either only missing results *or* all results
+        # (which don't already have a fallback key).
+        if always_include_fallback_keys:
+            # Build the fallback query as any part of the original query where
+            # the appservice didn't respond with a fallback key.
+            fallback_query = []
+
+            # Iterate each item in the original query and search the results
+            # from the appservice for that user ID / device ID. If it is found,
+            # check if any of the keys match the requested algorithm & are a
+            # fallback key.
+            for user_id, device_id, algorithm in local_query:
+                # Check if the appservice responded for this query.
+                as_result = appservice_results.get(user_id, {}).get(device_id, {})
+                found_otk = False
+                for key_id, key_json in as_result.items():
+                    if key_id.startswith(f"{algorithm}:"):
+                        # A OTK or fallback key was found for this query.
+                        found_otk = True
+                        # A fallback key was found for this query, no need to
+                        # query further.
+                        if key_json.get("fallback", False):
+                            break
+
+                else:
+                    # No fallback key was found from appservices, query for it.
+                    # Only mark the fallback key as used if no OTK was found
+                    # (from either the database or appservices).
+                    mark_as_used = not found_otk and not any(
+                        key_id.startswith(f"{algorithm}:")
+                        for key_id in otk_results.get(user_id, {})
+                        .get(device_id, {})
+                        .keys()
+                    )
+                    fallback_query.append((user_id, device_id, algorithm, mark_as_used))
+
+        else:
+            # All fallback keys get marked as used.
+            fallback_query = [
+                (user_id, device_id, algorithm, True)
+                for user_id, device_id, algorithm in not_found
+            ]
 
         # For each user that does not have a one-time keys available, see if
         # there is a fallback key.
-        fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
+        fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
 
         # Return the results in order, each item from the input query should
         # only appear once in the combined list.
-        return (otk_results, *appservice_results, fallback_results)
+        return (otk_results, appservice_results, fallback_results)
 
     @trace
     async def claim_one_time_keys(
-        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
+        self,
+        query: Dict[str, Dict[str, Dict[str, str]]],
+        timeout: Optional[int],
+        always_include_fallback_keys: bool,
     ) -> JsonDict:
         local_query: List[Tuple[str, str, str]] = []
         remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
@@ -617,7 +669,9 @@ class E2eKeysHandler:
         set_tag("local_key_query", str(local_query))
         set_tag("remote_key_query", str(remote_queries))
 
-        results = await self.claim_local_one_time_keys(local_query)
+        results = await self.claim_local_one_time_keys(
+            local_query, always_include_fallback_keys
+        )
 
         # A map of user ID -> device ID -> key ID -> key.
         json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
@@ -625,7 +679,9 @@ class E2eKeysHandler:
             for user_id, device_keys in result.items():
                 for device_id, keys in device_keys.items():
                     for key_id, key in keys.items():
-                        json_result.setdefault(user_id, {})[device_id] = {key_id: key}
+                        json_result.setdefault(user_id, {}).setdefault(
+                            device_id, {}
+                        ).update({key_id: key})
 
         # Remote failures.
         failures: Dict[str, JsonDict] = {}
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 6209b79b01..2a25094109 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+import re
 from typing import TYPE_CHECKING, Any, Optional, Tuple
 
 from synapse.api.errors import InvalidAPICallError, SynapseError
@@ -288,7 +289,33 @@ class OneTimeKeyServlet(RestServlet):
         await self.auth.get_user_by_req(request, allow_guest=True)
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
+        result = await self.e2e_keys_handler.claim_one_time_keys(
+            body, timeout, always_include_fallback_keys=False
+        )
+        return 200, result
+
+
+class UnstableOneTimeKeyServlet(RestServlet):
+    """
+    Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
+    fallback keys in the response.
+    """
+
+    PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
+    CATEGORY = "Encryption requests"
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await self.auth.get_user_by_req(request, allow_guest=True)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
+        body = parse_json_object_from_request(request)
+        result = await self.e2e_keys_handler.claim_one_time_keys(
+            body, timeout, always_include_fallback_keys=True
+        )
         return 200, result
 
 
@@ -394,6 +421,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     KeyQueryServlet(hs).register(http_server)
     KeyChangesServlet(hs).register(http_server)
     OneTimeKeyServlet(hs).register(http_server)
+    if hs.config.experimental.msc3983_appservice_otk_claims:
+        UnstableOneTimeKeyServlet(hs).register(http_server)
     if hs.config.worker.worker_app is None:
         SigningKeyUploadServlet(hs).register(http_server)
         SignaturesUploadServlet(hs).register(http_server)
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index dc7768c50c..1a4ae55304 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1149,18 +1149,19 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         return results, missing
 
     async def claim_e2e_fallback_keys(
-        self, query_list: Iterable[Tuple[str, str, str]]
+        self, query_list: Iterable[Tuple[str, str, str, bool]]
     ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
         """Take a list of fallback keys out of the database.
 
         Args:
-            query_list: An iterable of tuples of (user ID, device ID, algorithm).
+            query_list: An iterable of tuples of
+                (user ID, device ID, algorithm, whether the key should be marked as used).
 
         Returns:
             A map of user ID -> a map device ID -> a map of key ID -> JSON.
         """
         results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
-        for user_id, device_id, algorithm in query_list:
+        for user_id, device_id, algorithm, mark_as_used in query_list:
             row = await self.db_pool.simple_select_one(
                 table="e2e_fallback_keys_json",
                 keyvalues={
@@ -1180,7 +1181,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             used = row["used"]
 
             # Mark fallback key as used if not already.
-            if not used:
+            if not used and mark_as_used:
                 await self.db_pool.simple_update_one(
                     table="e2e_fallback_keys_json",
                     keyvalues={
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 013b9ee550..18edebd652 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -160,7 +160,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         res2 = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -203,7 +205,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         # key
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -220,7 +224,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         # claiming an OTK again should return the same fallback key
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -267,7 +273,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -277,7 +285,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -296,7 +306,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         claim_res = self.get_success(
             self.handler.claim_one_time_keys(
-                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -304,6 +316,75 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
         )
 
+    def test_fallback_key_always_returned(self) -> None:
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        fallback_key = {"alg1:k1": "fallback_key1"}
+        otk = {"alg1:k2": "key2"}
+
+        # we shouldn't have any unused fallback keys yet
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(res, [])
+
+        # Upload a OTK & fallback key.
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"one_time_keys": otk, "fallback_keys": fallback_key},
+            )
+        )
+
+        # we should now have an unused alg1 key
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Claiming an OTK and requesting to always return the fallback key should
+        # return both.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}},
+            },
+        )
+
+        # This should not mark the key as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Claiming an OTK again should return only the fallback key.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+        # And mark it as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(fallback_res, [])
+
     def test_replace_master_key(self) -> None:
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
@@ -1004,6 +1085,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
                     }
                 },
                 timeout=None,
+                always_include_fallback_keys=False,
             )
         )
         self.assertEqual(
@@ -1016,6 +1098,153 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
         )
 
+    @override_config({"experimental_features": {"msc3983_appservice_otk_claims": True}})
+    def test_query_appservice_with_fallback(self) -> None:
+        local_user = "@boris:" + self.hs.hostname
+        device_id_1 = "xyz"
+        fallback_key = {"alg1:k1": {"desc": "fallback_key1", "fallback": True}}
+        otk = {"alg1:k2": {"desc": "key2"}}
+        as_fallback_key = {"alg1:k3": {"desc": "fallback_key3", "fallback": True}}
+        as_otk = {"alg1:k4": {"desc": "key4"}}
+
+        # Inject an appservice interested in this user.
+        appservice = ApplicationService(
+            token="i_am_an_app_service",
+            id="1234",
+            namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
+            # Note: this user does not have to match the regex above
+            sender="@as_main:test",
+        )
+        self.hs.get_datastores().main.services_cache = [appservice]
+        self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+            [appservice]
+        )
+
+        # Setup a response.
+        self.appservice_api.claim_client_keys.return_value = make_awaitable(
+            ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
+        )
+
+        # Claim OTKs, which will ask the appservice and do nothing else.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {
+                    local_user: {device_id_1: {**as_otk, **as_fallback_key}}
+                },
+            },
+        )
+
+        # Now upload a fallback key.
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(res, [])
+
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id_1,
+                {"fallback_keys": fallback_key},
+            )
+        )
+
+        # we should now have an unused alg1 key
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # The appservice will return only the OTK.
+        self.appservice_api.claim_client_keys.return_value = make_awaitable(
+            ({local_user: {device_id_1: as_otk}}, [])
+        )
+
+        # Claim OTKs, which should return the OTK from the appservice and the
+        # uploaded fallback key.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {
+                    local_user: {device_id_1: {**as_otk, **fallback_key}}
+                },
+            },
+        )
+
+        # But the fallback key should not be marked as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Now upload a OTK.
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id_1,
+                {"one_time_keys": otk},
+            )
+        )
+
+        # Claim OTKs, which will return information only from the database.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {local_user: {device_id_1: {**otk, **fallback_key}}},
+            },
+        )
+
+        # But the fallback key should not be marked as used.
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id_1)
+        )
+        self.assertEqual(fallback_res, ["alg1"])
+
+        # Finally, return only the fallback key from the appservice.
+        self.appservice_api.claim_client_keys.return_value = make_awaitable(
+            ({local_user: {device_id_1: as_fallback_key}}, [])
+        )
+
+        # Claim OTKs, which will return only the fallback key from the database.
+        claim_res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id_1: "alg1"}}},
+                timeout=None,
+                always_include_fallback_keys=True,
+            )
+        )
+        self.assertEqual(
+            claim_res,
+            {
+                "failures": {},
+                "one_time_keys": {local_user: {device_id_1: as_fallback_key}},
+            },
+        )
+
     @override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
     def test_query_local_devices_appservice(self) -> None:
         """Test that querying of appservices for keys overrides responses from the database."""