diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index c5ff16adf3..8f05727652 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -19,11 +19,12 @@ import simplejson as json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
-import synapse.api.errors
-import synapse.server
-import synapse.types
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.http.servlet import (
+ RestServlet, parse_json_object_from_request, parse_integer
+)
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet):
device_id = requester.device_id
if device_id is None:
- raise synapse.api.errors.SynapseError(
+ raise SynapseError(
400,
"To upload keys, you must pass device_id when authenticating"
)
@@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id):
yield self.auth.get_user_by_req(request)
+ timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.e2e_keys_handler.query_devices(body)
+ result = yield self.e2e_keys_handler.query_devices(body, timeout)
defer.returnValue(result)
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
requester = yield self.auth.get_user_by_req(request)
+ timeout = parse_integer(request, "timeout", 10 * 1000)
auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
result = yield self.e2e_keys_handler.query_devices(
- {"device_keys": {user_id: device_ids}}
+ {"device_keys": {user_id: device_ids}},
+ timeout,
)
defer.returnValue(result)
@@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet):
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.federation = hs.get_replication_layer()
- self.is_mine = hs.is_mine
+ self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
+ timeout = parse_integer(request, "timeout", 10 * 1000)
result = yield self.handle_request(
- {"one_time_keys": {user_id: {device_id: algorithm}}}
+ {"one_time_keys": {user_id: {device_id: algorithm}}},
+ timeout,
)
defer.returnValue(result)
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
+ timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.handle_request(body)
+ result = yield self.handle_request(body, timeout)
defer.returnValue(result)
@defer.inlineCallbacks
- def handle_request(self, body):
+ def handle_request(self, body, timeout):
local_query = []
remote_queries = {}
+
for user_id, device_keys in body.get("one_time_keys", {}).items():
- user = UserID.from_string(user_id)
- if self.is_mine(user):
+ if self.is_mine_id(user_id):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
- remote_queries.setdefault(user.domain, {})[user_id] = (
- device_keys
- )
+ domain = get_domain_from_id(user_id)
+ remote_queries.setdefault(domain, {})[user_id] = device_keys
+
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
+ failures = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
@@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet):
key_id: json.loads(json_bytes)
}
- for destination, device_keys in remote_queries.items():
- remote_result = yield self.federation.claim_client_keys(
- destination, {"one_time_keys": device_keys}
- )
- for user_id, keys in remote_result["one_time_keys"].items():
- if user_id in device_keys:
- json_result[user_id] = keys
-
- defer.returnValue((200, {"one_time_keys": json_result}))
+ @defer.inlineCallbacks
+ def claim_client_keys(destination):
+ device_keys = remote_queries[destination]
+ try:
+ remote_result = yield self.federation.claim_client_keys(
+ destination,
+ {"one_time_keys": device_keys},
+ timeout=timeout
+ )
+ for user_id, keys in remote_result["one_time_keys"].items():
+ if user_id in device_keys:
+ json_result[user_id] = keys
+ except CodeMessageException as e:
+ failures[destination] = {
+ "status": e.code, "message": e.message
+ }
+
+ yield preserve_context_over_deferred(defer.gatherResults([
+ preserve_fn(claim_client_keys)(destination)
+ for destination in remote_queries
+ ]))
+
+ defer.returnValue((200, {
+ "one_time_keys": json_result,
+ "failures": failures
+ }))
def register_servlets(hs, http_server):
|