diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 2a25094109..9bbab5e624 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -16,7 +16,8 @@
import logging
import re
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from synapse.api.errors import InvalidAPICallError, SynapseError
from synapse.http.server import HttpServer
@@ -289,16 +290,40 @@ class OneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
+
+ # Generate a count for each algorithm, which is hard-coded to 1.
+ query: Dict[str, Dict[str, Dict[str, int]]] = {}
+ for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+ for device_id, algorithm in one_time_keys.items():
+ query.setdefault(user_id, {})[device_id] = {algorithm: 1}
+
result = await self.e2e_keys_handler.claim_one_time_keys(
- body, timeout, always_include_fallback_keys=False
+ query, timeout, always_include_fallback_keys=False
)
return 200, result
class UnstableOneTimeKeyServlet(RestServlet):
"""
- Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
- fallback keys in the response.
+ Identical to the stable endpoint (OneTimeKeyServlet) except it allows for
+ querying for multiple OTKs at once and always includes fallback keys in the
+ response.
+
+ POST /keys/claim HTTP/1.1
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": ["<algorithm>", ...]
+ } } }
+
+ HTTP/1.1 200 OK
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ } } } }
+
"""
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
@@ -313,8 +338,15 @@ class UnstableOneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
+
+ # Generate a count for each algorithm.
+ query: Dict[str, Dict[str, Dict[str, int]]] = {}
+ for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+ for device_id, algorithms in one_time_keys.items():
+ query.setdefault(user_id, {})[device_id] = Counter(algorithms)
+
result = await self.e2e_keys_handler.claim_one_time_keys(
- body, timeout, always_include_fallback_keys=True
+ query, timeout, always_include_fallback_keys=True
)
return 200, result
|