summary refs log tree commit diff
path: root/synapse/rest/client/keys.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-04-27 12:57:46 -0400
committerGitHub <noreply@github.com>2023-04-27 12:57:46 -0400
commit57aeeb308b39c4fd455682966eabc9c0fa17c65d (patch)
tree3b59e2a367f7894a2adfca66c6579fe317723a39 /synapse/rest/client/keys.py
parentAdd type hints to schema deltas (#15497) (diff)
downloadsynapse-57aeeb308b39c4fd455682966eabc9c0fa17c65d.tar.xz
Add support for claiming multiple OTKs at once. (#15468)
MSC3983 provides a way to request multiple OTKs at once from appservices,
this extends this concept to the Client-Server API.

Note that this will likely be spit out into a separate MSC, but is currently part of
MSC3983.
Diffstat (limited to 'synapse/rest/client/keys.py')
-rw-r--r--synapse/rest/client/keys.py42
1 files changed, 37 insertions, 5 deletions
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 2a25094109..9bbab5e624 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -16,7 +16,8 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 
 from synapse.api.errors import InvalidAPICallError, SynapseError
 from synapse.http.server import HttpServer
@@ -289,16 +290,40 @@ class OneTimeKeyServlet(RestServlet):
         await self.auth.get_user_by_req(request, allow_guest=True)
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
+
+        # Generate a count for each algorithm, which is hard-coded to 1.
+        query: Dict[str, Dict[str, Dict[str, int]]] = {}
+        for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+            for device_id, algorithm in one_time_keys.items():
+                query.setdefault(user_id, {})[device_id] = {algorithm: 1}
+
         result = await self.e2e_keys_handler.claim_one_time_keys(
-            body, timeout, always_include_fallback_keys=False
+            query, timeout, always_include_fallback_keys=False
         )
         return 200, result
 
 
 class UnstableOneTimeKeyServlet(RestServlet):
     """
-    Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
-    fallback keys in the response.
+    Identical to the stable endpoint (OneTimeKeyServlet) except it allows for
+    querying for multiple OTKs at once and always includes fallback keys in the
+    response.
+
+    POST /keys/claim HTTP/1.1
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": ["<algorithm>", ...]
+    } } }
+
+    HTTP/1.1 200 OK
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": {
+            "<algorithm>:<key_id>": "<key_base64>"
+    } } } }
+
     """
 
     PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
@@ -313,8 +338,15 @@ class UnstableOneTimeKeyServlet(RestServlet):
         await self.auth.get_user_by_req(request, allow_guest=True)
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
+
+        # Generate a count for each algorithm.
+        query: Dict[str, Dict[str, Dict[str, int]]] = {}
+        for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+            for device_id, algorithms in one_time_keys.items():
+                query.setdefault(user_id, {})[device_id] = Counter(algorithms)
+
         result = await self.e2e_keys_handler.claim_one_time_keys(
-            body, timeout, always_include_fallback_keys=True
+            query, timeout, always_include_fallback_keys=True
         )
         return 200, result