diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 6209b79b01..9bbab5e624 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -15,7 +15,9 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+import re
+from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from synapse.api.errors import InvalidAPICallError, SynapseError
from synapse.http.server import HttpServer
@@ -288,7 +290,64 @@ class OneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
+
+ # Generate a count for each algorithm, which is hard-coded to 1.
+ query: Dict[str, Dict[str, Dict[str, int]]] = {}
+ for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+ for device_id, algorithm in one_time_keys.items():
+ query.setdefault(user_id, {})[device_id] = {algorithm: 1}
+
+ result = await self.e2e_keys_handler.claim_one_time_keys(
+ query, timeout, always_include_fallback_keys=False
+ )
+ return 200, result
+
+
+class UnstableOneTimeKeyServlet(RestServlet):
+ """
+ Identical to the stable endpoint (OneTimeKeyServlet) except it allows for
+ querying for multiple OTKs at once and always includes fallback keys in the
+ response.
+
+ POST /keys/claim HTTP/1.1
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": ["<algorithm>", ...]
+ } } }
+
+ HTTP/1.1 200 OK
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ } } } }
+
+ """
+
+ PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
+ CATEGORY = "Encryption requests"
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.auth.get_user_by_req(request, allow_guest=True)
+ timeout = parse_integer(request, "timeout", 10 * 1000)
+ body = parse_json_object_from_request(request)
+
+ # Generate a count for each algorithm.
+ query: Dict[str, Dict[str, Dict[str, int]]] = {}
+ for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+ for device_id, algorithms in one_time_keys.items():
+ query.setdefault(user_id, {})[device_id] = Counter(algorithms)
+
+ result = await self.e2e_keys_handler.claim_one_time_keys(
+ query, timeout, always_include_fallback_keys=True
+ )
return 200, result
@@ -394,6 +453,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)
+ if hs.config.experimental.msc3983_appservice_otk_claims:
+ UnstableOneTimeKeyServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server)
|