summary refs log tree commit diff
path: root/synapse/rest/client/keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/keys.py')
-rw-r--r--synapse/rest/client/keys.py42
1 files changed, 37 insertions, 5 deletions
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 2a25094109..9bbab5e624 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -16,7 +16,8 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 
 from synapse.api.errors import InvalidAPICallError, SynapseError
 from synapse.http.server import HttpServer
@@ -289,16 +290,40 @@ class OneTimeKeyServlet(RestServlet):
         await self.auth.get_user_by_req(request, allow_guest=True)
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
+
+        # Generate a count for each algorithm, which is hard-coded to 1.
+        query: Dict[str, Dict[str, Dict[str, int]]] = {}
+        for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+            for device_id, algorithm in one_time_keys.items():
+                query.setdefault(user_id, {})[device_id] = {algorithm: 1}
+
         result = await self.e2e_keys_handler.claim_one_time_keys(
-            body, timeout, always_include_fallback_keys=False
+            query, timeout, always_include_fallback_keys=False
         )
         return 200, result
 
 
 class UnstableOneTimeKeyServlet(RestServlet):
     """
-    Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
-    fallback keys in the response.
+    Identical to the stable endpoint (OneTimeKeyServlet) except it allows for
+    querying for multiple OTKs at once and always includes fallback keys in the
+    response.
+
+    POST /keys/claim HTTP/1.1
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": ["<algorithm>", ...]
+    } } }
+
+    HTTP/1.1 200 OK
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": {
+            "<algorithm>:<key_id>": "<key_base64>"
+    } } } }
+
     """
 
     PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
@@ -313,8 +338,15 @@ class UnstableOneTimeKeyServlet(RestServlet):
         await self.auth.get_user_by_req(request, allow_guest=True)
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
+
+        # Generate a count for each algorithm.
+        query: Dict[str, Dict[str, Dict[str, int]]] = {}
+        for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+            for device_id, algorithms in one_time_keys.items():
+                query.setdefault(user_id, {})[device_id] = Counter(algorithms)
+
         result = await self.e2e_keys_handler.claim_one_time_keys(
-            body, timeout, always_include_fallback_keys=True
+            query, timeout, always_include_fallback_keys=True
         )
         return 200, result