From 949c2c54352f5a1fe2d8de39c4ddebc1f1e13aac Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 12 Sep 2016 18:17:09 +0100 Subject: Add a timeout parameter for end2end key queries. Add a timeout parameter for controlling how long synapse will wait for responses from remote servers. For servers that fail include how they failed to make it easier to debug. Fetch keys from different servers in parallel rather than in series. Set the default timeout to 10s. --- synapse/rest/client/v2_alpha/keys.py | 77 ++++++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 26 deletions(-) (limited to 'synapse/rest') diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index c5ff16adf3..8f05727652 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -19,11 +19,12 @@ import simplejson as json from canonicaljson import encode_canonical_json from twisted.internet import defer -import synapse.api.errors -import synapse.server -import synapse.types -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import UserID +from synapse.api.errors import SynapseError, CodeMessageException +from synapse.http.servlet import ( + RestServlet, parse_json_object_from_request, parse_integer +) +from synapse.types import get_domain_from_id +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -88,7 +89,7 @@ class KeyUploadServlet(RestServlet): device_id = requester.device_id if device_id is None: - raise synapse.api.errors.SynapseError( + raise SynapseError( 400, "To upload keys, you must pass device_id when authenticating" ) @@ -195,18 +196,21 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.query_devices(body) + result = yield self.e2e_keys_handler.query_devices(body, timeout) defer.returnValue(result) @defer.inlineCallbacks def on_GET(self, request, user_id, device_id): requester = yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) auth_user_id = requester.user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] result = yield self.e2e_keys_handler.query_devices( - {"device_keys": {user_id: device_ids}} + {"device_keys": {user_id: device_ids}}, + timeout, ) defer.returnValue(result) @@ -244,39 +248,43 @@ class OneTimeKeyServlet(RestServlet): self.auth = hs.get_auth() self.clock = hs.get_clock() self.federation = hs.get_replication_layer() - self.is_mine = hs.is_mine + self.is_mine_id = hs.is_mine_id @defer.inlineCallbacks def on_GET(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) result = yield self.handle_request( - {"one_time_keys": {user_id: {device_id: algorithm}}} + {"one_time_keys": {user_id: {device_id: algorithm}}}, + timeout, ) defer.returnValue(result) @defer.inlineCallbacks def on_POST(self, request, user_id, device_id, algorithm): yield self.auth.get_user_by_req(request) + timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.handle_request(body) + result = yield self.handle_request(body, timeout) defer.returnValue(result) @defer.inlineCallbacks - def handle_request(self, body): + def handle_request(self, body, timeout): local_query = [] remote_queries = {} + for user_id, device_keys in body.get("one_time_keys", {}).items(): - user = UserID.from_string(user_id) - if self.is_mine(user): + if self.is_mine_id(user_id): for device_id, algorithm in device_keys.items(): local_query.append((user_id, device_id, algorithm)) else: - remote_queries.setdefault(user.domain, {})[user_id] = ( - device_keys - ) + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_keys + results = yield self.store.claim_e2e_one_time_keys(local_query) json_result = {} + failures = {} for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): for key_id, json_bytes in keys.items(): @@ -284,15 +292,32 @@ class OneTimeKeyServlet(RestServlet): key_id: json.loads(json_bytes) } - for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.claim_client_keys( - destination, {"one_time_keys": device_keys} - ) - for user_id, keys in remote_result["one_time_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys - - defer.returnValue((200, {"one_time_keys": json_result})) + @defer.inlineCallbacks + def claim_client_keys(destination): + device_keys = remote_queries[destination] + try: + remote_result = yield self.federation.claim_client_keys( + destination, + {"one_time_keys": device_keys}, + timeout=timeout + ) + for user_id, keys in remote_result["one_time_keys"].items(): + if user_id in device_keys: + json_result[user_id] = keys + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(claim_client_keys)(destination) + for destination in remote_queries + ])) + + defer.returnValue((200, { + "one_time_keys": json_result, + "failures": failures + })) def register_servlets(hs, http_server): -- cgit 1.4.1