summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-04-27 12:57:46 -0400
committerGitHub <noreply@github.com>2023-04-27 12:57:46 -0400
commit57aeeb308b39c4fd455682966eabc9c0fa17c65d (patch)
tree3b59e2a367f7894a2adfca66c6579fe317723a39 /synapse
parentAdd type hints to schema deltas (#15497) (diff)
downloadsynapse-57aeeb308b39c4fd455682966eabc9c0fa17c65d.tar.xz
Add support for claiming multiple OTKs at once. (#15468)
MSC3983 provides a way to request multiple OTKs at once from appservices,
this extends this concept to the Client-Server API.

Note that this will likely be spit out into a separate MSC, but is currently part of
MSC3983.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/appservice/api.py31
-rw-r--r--synapse/federation/federation_client.py49
-rw-r--r--synapse/federation/federation_server.py7
-rw-r--r--synapse/federation/transport/client.py49
-rw-r--r--synapse/federation/transport/server/federation.py25
-rw-r--r--synapse/handlers/appservice.py14
-rw-r--r--synapse/handlers/e2e_keys.py31
-rw-r--r--synapse/rest/client/keys.py42
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py77
9 files changed, 251 insertions, 74 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 86ddb1bb28..024098e9cb 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -442,8 +442,10 @@ class ApplicationServiceApi(SimpleHttpClient):
         return False
 
     async def claim_client_keys(
-        self, service: "ApplicationService", query: List[Tuple[str, str, str]]
-    ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
+        self, service: "ApplicationService", query: List[Tuple[str, str, str, int]]
+    ) -> Tuple[
+        Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
+    ]:
         """Claim one time keys from an application service.
 
         Note that any error (including a timeout) is treated as the application
@@ -469,8 +471,10 @@ class ApplicationServiceApi(SimpleHttpClient):
 
         # Create the expected payload shape.
         body: Dict[str, Dict[str, List[str]]] = {}
-        for user_id, device, algorithm in query:
-            body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
+        for user_id, device, algorithm, count in query:
+            body.setdefault(user_id, {}).setdefault(device, []).extend(
+                [algorithm] * count
+            )
 
         uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
         try:
@@ -493,11 +497,20 @@ class ApplicationServiceApi(SimpleHttpClient):
         # or if some are still missing.
         #
         # TODO This places a lot of faith in the response shape being correct.
-        missing = [
-            (user_id, device, algorithm)
-            for user_id, device, algorithm in query
-            if algorithm not in response.get(user_id, {}).get(device, [])
-        ]
+        missing = []
+        for user_id, device, algorithm, count in query:
+            # Count the number of keys in the response for this algorithm by
+            # checking which key IDs start with the algorithm. This uses that
+            # True == 1 in Python to generate a count.
+            response_count = sum(
+                key_id.startswith(f"{algorithm}:")
+                for key_id in response.get(user_id, {}).get(device, {})
+            )
+            count -= response_count
+            # If the appservice responds with fewer keys than requested, then
+            # consider the request unfulfilled.
+            if count > 0:
+                missing.append((user_id, device, algorithm, count))
 
         return response, missing
 
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index ba34573d46..0b2d1a78f7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -235,7 +235,10 @@ class FederationClient(FederationBase):
         )
 
     async def claim_client_keys(
-        self, destination: str, content: JsonDict, timeout: Optional[int]
+        self,
+        destination: str,
+        query: Dict[str, Dict[str, Dict[str, int]]],
+        timeout: Optional[int],
     ) -> JsonDict:
         """Claims one-time keys for a device hosted on a remote server.
 
@@ -247,6 +250,50 @@ class FederationClient(FederationBase):
             The JSON object from the response
         """
         sent_queries_counter.labels("client_one_time_keys").inc()
+
+        # Convert the query with counts into a stable and unstable query and check
+        # if attempting to claim more than 1 OTK.
+        content: Dict[str, Dict[str, str]] = {}
+        unstable_content: Dict[str, Dict[str, List[str]]] = {}
+        use_unstable = False
+        for user_id, one_time_keys in query.items():
+            for device_id, algorithms in one_time_keys.items():
+                if any(count > 1 for count in algorithms.values()):
+                    use_unstable = True
+                if algorithms:
+                    # For the stable query, choose only the first algorithm.
+                    content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
+                    # For the unstable query, repeat each algorithm by count, then
+                    # splat those into chain to get a flattened list of all algorithms.
+                    #
+                    # Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"].
+                    unstable_content.setdefault(user_id, {})[device_id] = list(
+                        itertools.chain(
+                            *(
+                                itertools.repeat(algorithm, count)
+                                for algorithm, count in algorithms.items()
+                            )
+                        )
+                    )
+
+        if use_unstable:
+            try:
+                return await self.transport_layer.claim_client_keys_unstable(
+                    destination, unstable_content, timeout
+                )
+            except HttpResponseException as e:
+                # If an error is received that is due to an unrecognised endpoint,
+                # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
+                # and raise.
+                if not is_unknown_endpoint(e):
+                    raise
+
+            logger.debug(
+                "Couldn't claim client keys with the unstable API, falling back to the v1 API"
+            )
+        else:
+            logger.debug("Skipping unstable claim client keys API")
+
         return await self.transport_layer.claim_client_keys(
             destination, content, timeout
         )
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index c618f3d7a6..ca43c7bfc0 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1005,13 +1005,8 @@ class FederationServer(FederationBase):
 
     @trace
     async def on_claim_client_keys(
-        self, origin: str, content: JsonDict, always_include_fallback_keys: bool
+        self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
     ) -> Dict[str, Any]:
-        query = []
-        for user_id, device_keys in content.get("one_time_keys", {}).items():
-            for device_id, algorithm in device_keys.items():
-                query.append((user_id, device_id, algorithm))
-
         log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
         results = await self._e2e_keys_handler.claim_local_one_time_keys(
             query, always_include_fallback_keys=always_include_fallback_keys
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index bedbd23ded..bc70b94f68 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -650,10 +650,10 @@ class TransportLayerClient:
 
         Response:
             {
-              "device_keys": {
+              "one_time_keys": {
                 "<user_id>": {
                   "<device_id>": {
-                    "<algorithm>:<key_id>": "<key_base64>"
+                    "<algorithm>:<key_id>": <OTK JSON>
                   }
                 }
               }
@@ -669,7 +669,50 @@ class TransportLayerClient:
         path = _create_v1_path("/user/keys/claim")
 
         return await self.client.post_json(
-            destination=destination, path=path, data=query_content, timeout=timeout
+            destination=destination,
+            path=path,
+            data={"one_time_keys": query_content},
+            timeout=timeout,
+        )
+
+    async def claim_client_keys_unstable(
+        self, destination: str, query_content: JsonDict, timeout: Optional[int]
+    ) -> JsonDict:
+        """Claim one-time keys for a list of devices hosted on a remote server.
+
+        Request:
+            {
+              "one_time_keys": {
+                "<user_id>": {
+                  "<device_id>": {"<algorithm>": <count>}
+                }
+              }
+            }
+
+        Response:
+            {
+              "one_time_keys": {
+                "<user_id>": {
+                  "<device_id>": {
+                    "<algorithm>:<key_id>": <OTK JSON>
+                  }
+                }
+              }
+            }
+
+        Args:
+            destination: The server to query.
+            query_content: The user ids to query.
+        Returns:
+            A dict containing the one-time keys.
+        """
+        path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")
+
+        return await self.client.post_json(
+            destination=destination,
+            path=path,
+            data={"one_time_keys": query_content},
+            timeout=timeout,
         )
 
     async def get_missing_events(
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index e2340d70d5..36b0362504 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -12,6 +12,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 import logging
+from collections import Counter
 from typing import (
     TYPE_CHECKING,
     Dict,
@@ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
     async def on_POST(
         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
     ) -> Tuple[int, JsonDict]:
+        # Generate a count for each algorithm, which is hard-coded to 1.
+        key_query: List[Tuple[str, str, str, int]] = []
+        for user_id, device_keys in content.get("one_time_keys", {}).items():
+            for device_id, algorithm in device_keys.items():
+                key_query.append((user_id, device_id, algorithm, 1))
+
         response = await self.handler.on_claim_client_keys(
-            origin, content, always_include_fallback_keys=False
+            key_query, always_include_fallback_keys=False
         )
         return 200, response
 
 
 class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
     """
-    Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
-    always includes fallback keys in the response.
+    Identical to the stable endpoint (FederationClientKeysClaimServlet) except
+    it allows for querying for multiple OTKs at once and always includes fallback
+    keys in the response.
     """
 
     PREFIX = FEDERATION_UNSTABLE_PREFIX
@@ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
     async def on_POST(
         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
     ) -> Tuple[int, JsonDict]:
+        # Generate a count for each algorithm.
+        key_query: List[Tuple[str, str, str, int]] = []
+        for user_id, device_keys in content.get("one_time_keys", {}).items():
+            for device_id, algorithms in device_keys.items():
+                counts = Counter(algorithms)
+                for algorithm, count in counts.items():
+                    key_query.append((user_id, device_id, algorithm, count))
+
         response = await self.handler.on_claim_client_keys(
-            origin, content, always_include_fallback_keys=True
+            key_query, always_include_fallback_keys=True
         )
         return 200, response
 
@@ -805,6 +821,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationClientKeysQueryServlet,
     FederationUserDevicesQueryServlet,
     FederationClientKeysClaimServlet,
+    FederationUnstableClientKeysClaimServlet,
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
     FederationVersionServlet,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 4ca2bc0420..6429545c98 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -841,8 +841,10 @@ class ApplicationServicesHandler:
         return True
 
     async def claim_e2e_one_time_keys(
-        self, query: Iterable[Tuple[str, str, str]]
-    ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
+        self, query: Iterable[Tuple[str, str, str, int]]
+    ) -> Tuple[
+        Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
+    ]:
         """Claim one time keys from application services.
 
         Users which are exclusively owned by an application service are sent a
@@ -863,18 +865,18 @@ class ApplicationServicesHandler:
         services = self.store.get_app_services()
 
         # Partition the users by appservice.
-        query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
+        query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {}
         missing = []
-        for user_id, device, algorithm in query:
+        for user_id, device, algorithm, count in query:
             if not self.store.get_if_app_services_interested_in_user(user_id):
-                missing.append((user_id, device, algorithm))
+                missing.append((user_id, device, algorithm, count))
                 continue
 
             # Find the associated appservice.
             for service in services:
                 if service.is_exclusive_user(user_id):
                     query_by_appservice.setdefault(service.id, []).append(
-                        (user_id, device, algorithm)
+                        (user_id, device, algorithm, count)
                     )
                     continue
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d1ab95126c..24741b667b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -564,7 +564,7 @@ class E2eKeysHandler:
 
     async def claim_local_one_time_keys(
         self,
-        local_query: List[Tuple[str, str, str]],
+        local_query: List[Tuple[str, str, str, int]],
         always_include_fallback_keys: bool,
     ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
         """Claim one time keys for local users.
@@ -581,6 +581,12 @@ class E2eKeysHandler:
             An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
         """
 
+        # Cap the number of OTKs that can be claimed at once to avoid abuse.
+        local_query = [
+            (user_id, device_id, algorithm, min(count, 5))
+            for user_id, device_id, algorithm, count in local_query
+        ]
+
         otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
 
         # If the application services have not provided any keys via the C-S
@@ -607,7 +613,7 @@ class E2eKeysHandler:
             # from the appservice for that user ID / device ID. If it is found,
             # check if any of the keys match the requested algorithm & are a
             # fallback key.
-            for user_id, device_id, algorithm in local_query:
+            for user_id, device_id, algorithm, _count in local_query:
                 # Check if the appservice responded for this query.
                 as_result = appservice_results.get(user_id, {}).get(device_id, {})
                 found_otk = False
@@ -630,13 +636,17 @@ class E2eKeysHandler:
                         .get(device_id, {})
                         .keys()
                     )
+                    # Note that it doesn't make sense to request more than 1 fallback key
+                    # per (user_id, device_id, algorithm).
                     fallback_query.append((user_id, device_id, algorithm, mark_as_used))
 
         else:
             # All fallback keys get marked as used.
             fallback_query = [
+                # Note that it doesn't make sense to request more than 1 fallback key
+                # per (user_id, device_id, algorithm).
                 (user_id, device_id, algorithm, True)
-                for user_id, device_id, algorithm in not_found
+                for user_id, device_id, algorithm, count in not_found
             ]
 
         # For each user that does not have a one-time keys available, see if
@@ -650,18 +660,19 @@ class E2eKeysHandler:
     @trace
     async def claim_one_time_keys(
         self,
-        query: Dict[str, Dict[str, Dict[str, str]]],
+        query: Dict[str, Dict[str, Dict[str, int]]],
         timeout: Optional[int],
         always_include_fallback_keys: bool,
     ) -> JsonDict:
-        local_query: List[Tuple[str, str, str]] = []
-        remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
+        local_query: List[Tuple[str, str, str, int]] = []
+        remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
 
-        for user_id, one_time_keys in query.get("one_time_keys", {}).items():
+        for user_id, one_time_keys in query.items():
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
-                for device_id, algorithm in one_time_keys.items():
-                    local_query.append((user_id, device_id, algorithm))
+                for device_id, algorithms in one_time_keys.items():
+                    for algorithm, count in algorithms.items():
+                        local_query.append((user_id, device_id, algorithm, count))
             else:
                 domain = get_domain_from_id(user_id)
                 remote_queries.setdefault(domain, {})[user_id] = one_time_keys
@@ -692,7 +703,7 @@ class E2eKeysHandler:
             device_keys = remote_queries[destination]
             try:
                 remote_result = await self.federation.claim_client_keys(
-                    destination, {"one_time_keys": device_keys}, timeout=timeout
+                    destination, device_keys, timeout=timeout
                 )
                 for user_id, keys in remote_result["one_time_keys"].items():
                     if user_id in device_keys:
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 2a25094109..9bbab5e624 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -16,7 +16,8 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 
 from synapse.api.errors import InvalidAPICallError, SynapseError
 from synapse.http.server import HttpServer
@@ -289,16 +290,40 @@ class OneTimeKeyServlet(RestServlet):
         await self.auth.get_user_by_req(request, allow_guest=True)
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
+
+        # Generate a count for each algorithm, which is hard-coded to 1.
+        query: Dict[str, Dict[str, Dict[str, int]]] = {}
+        for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+            for device_id, algorithm in one_time_keys.items():
+                query.setdefault(user_id, {})[device_id] = {algorithm: 1}
+
         result = await self.e2e_keys_handler.claim_one_time_keys(
-            body, timeout, always_include_fallback_keys=False
+            query, timeout, always_include_fallback_keys=False
         )
         return 200, result
 
 
 class UnstableOneTimeKeyServlet(RestServlet):
     """
-    Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
-    fallback keys in the response.
+    Identical to the stable endpoint (OneTimeKeyServlet) except it allows for
+    querying for multiple OTKs at once and always includes fallback keys in the
+    response.
+
+    POST /keys/claim HTTP/1.1
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": ["<algorithm>", ...]
+    } } }
+
+    HTTP/1.1 200 OK
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": {
+            "<algorithm>:<key_id>": "<key_base64>"
+    } } } }
+
     """
 
     PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
@@ -313,8 +338,15 @@ class UnstableOneTimeKeyServlet(RestServlet):
         await self.auth.get_user_by_req(request, allow_guest=True)
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
+
+        # Generate a count for each algorithm.
+        query: Dict[str, Dict[str, Dict[str, int]]] = {}
+        for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+            for device_id, algorithms in one_time_keys.items():
+                query.setdefault(user_id, {})[device_id] = Counter(algorithms)
+
         result = await self.e2e_keys_handler.claim_one_time_keys(
-            body, timeout, always_include_fallback_keys=True
+            query, timeout, always_include_fallback_keys=True
         )
         return 200, result
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1a4ae55304..4bc391f213 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1027,8 +1027,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         ...
 
     async def claim_e2e_one_time_keys(
-        self, query_list: Iterable[Tuple[str, str, str]]
-    ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
+        self, query_list: Iterable[Tuple[str, str, str, int]]
+    ) -> Tuple[
+        Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
+    ]:
         """Take a list of one time keys out of the database.
 
         Args:
@@ -1043,8 +1045,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         @trace
         def _claim_e2e_one_time_key_simple(
-            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
-        ) -> Optional[Tuple[str, str]]:
+            txn: LoggingTransaction,
+            user_id: str,
+            device_id: str,
+            algorithm: str,
+            count: int,
+        ) -> List[Tuple[str, str]]:
             """Claim OTK for device for DBs that don't support RETURNING.
 
             Returns:
@@ -1055,36 +1061,41 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             sql = """
                 SELECT key_id, key_json FROM e2e_one_time_keys_json
                 WHERE user_id = ? AND device_id = ? AND algorithm = ?
-                LIMIT 1
+                LIMIT ?
             """
 
-            txn.execute(sql, (user_id, device_id, algorithm))
-            otk_row = txn.fetchone()
-            if otk_row is None:
-                return None
+            txn.execute(sql, (user_id, device_id, algorithm, count))
+            otk_rows = list(txn)
+            if not otk_rows:
+                return []
 
-            key_id, key_json = otk_row
-
-            self.db_pool.simple_delete_one_txn(
+            self.db_pool.simple_delete_many_txn(
                 txn,
                 table="e2e_one_time_keys_json",
+                column="key_id",
+                values=[otk_row[0] for otk_row in otk_rows],
                 keyvalues={
                     "user_id": user_id,
                     "device_id": device_id,
                     "algorithm": algorithm,
-                    "key_id": key_id,
                 },
             )
             self._invalidate_cache_and_stream(
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-            return f"{algorithm}:{key_id}", key_json
+            return [
+                (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
+            ]
 
         @trace
         def _claim_e2e_one_time_key_returning(
-            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
-        ) -> Optional[Tuple[str, str]]:
+            txn: LoggingTransaction,
+            user_id: str,
+            device_id: str,
+            algorithm: str,
+            count: int,
+        ) -> List[Tuple[str, str]]:
             """Claim OTK for device for DBs that support RETURNING.
 
             Returns:
@@ -1099,28 +1110,30 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                     AND key_id IN (
                         SELECT key_id FROM e2e_one_time_keys_json
                         WHERE user_id = ? AND device_id = ? AND algorithm = ?
-                        LIMIT 1
+                        LIMIT ?
                     )
                 RETURNING key_id, key_json
             """
 
             txn.execute(
-                sql, (user_id, device_id, algorithm, user_id, device_id, algorithm)
+                sql,
+                (user_id, device_id, algorithm, user_id, device_id, algorithm, count),
             )
-            otk_row = txn.fetchone()
-            if otk_row is None:
-                return None
+            otk_rows = list(txn)
+            if not otk_rows:
+                return []
 
             self._invalidate_cache_and_stream(
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-            key_id, key_json = otk_row
-            return f"{algorithm}:{key_id}", key_json
+            return [
+                (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
+            ]
 
         results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
-        missing: List[Tuple[str, str, str]] = []
-        for user_id, device_id, algorithm in query_list:
+        missing: List[Tuple[str, str, str, int]] = []
+        for user_id, device_id, algorithm, count in query_list:
             if self.database_engine.supports_returning:
                 # If we support RETURNING clause we can use a single query that
                 # allows us to use autocommit mode.
@@ -1130,21 +1143,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
                 db_autocommit = False
 
-            claim_row = await self.db_pool.runInteraction(
+            claim_rows = await self.db_pool.runInteraction(
                 "claim_e2e_one_time_keys",
                 _claim_e2e_one_time_key,
                 user_id,
                 device_id,
                 algorithm,
+                count,
                 db_autocommit=db_autocommit,
             )
-            if claim_row:
+            if claim_rows:
                 device_results = results.setdefault(user_id, {}).setdefault(
                     device_id, {}
                 )
-                device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
-            else:
-                missing.append((user_id, device_id, algorithm))
+                for claim_row in claim_rows:
+                    device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+            # Did we get enough OTKs?
+            count -= len(claim_rows)
+            if count:
+                missing.append((user_id, device_id, algorithm, count))
 
         return results, missing