diff options
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r-- | synapse/handlers/e2e_keys.py | 101 |
1 files changed, 55 insertions, 46 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 668a90e495..5816bf8b4f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,15 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ujson as json import logging -from canonicaljson import encode_canonical_json +from six import iteritems + +from canonicaljson import encode_canonical_json, json + from twisted.internet import defer -from synapse.api.errors import SynapseError, CodeMessageException -from synapse.types import get_domain_from_id -from synapse.util.logcontext import preserve_fn, make_deferred_yieldable +from synapse.api.errors import CodeMessageException, FederationDeniedError, SynapseError +from synapse.types import UserID, get_domain_from_id +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -30,15 +33,15 @@ logger = logging.getLogger(__name__) class E2eKeysHandler(object): def __init__(self, hs): self.store = hs.get_datastore() - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() - self.is_mine_id = hs.is_mine_id + self.is_mine = hs.is_mine self.clock = hs.get_clock() # doesn't really work as part of the generic query API, because the # query request requires an object POST, but we abuse the # "query handler" interface. - self.federation.register_query_handler( + hs.get_federation_registry().register_query_handler( "client_keys", self.on_federation_query_client_keys ) @@ -70,12 +73,13 @@ class E2eKeysHandler(object): remote_queries = {} for user_id, device_ids in device_keys_query.items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): local_query[user_id] = device_ids else: remote_queries[user_id] = device_ids - # Firt get local devices. + # First get local devices. failures = {} results = {} if local_query: @@ -88,7 +92,7 @@ class E2eKeysHandler(object): remote_queries_not_in_cache = {} if remote_queries: query_list = [] - for user_id, device_ids in remote_queries.iteritems(): + for user_id, device_ids in iteritems(remote_queries): if device_ids: query_list.extend((user_id, device_id) for device_id in device_ids) else: @@ -99,9 +103,9 @@ class E2eKeysHandler(object): query_list ) ) - for user_id, devices in remote_results.iteritems(): + for user_id, devices in iteritems(remote_results): user_devices = results.setdefault(user_id, {}) - for device_id, device in devices.iteritems(): + for device_id, device in iteritems(devices): keys = device.get("keys", None) device_display_name = device.get("device_display_name", None) if keys: @@ -131,24 +135,13 @@ class E2eKeysHandler(object): if user_id in destination_query: results[user_id] = keys - except CodeMessageException as e: - failures[destination] = { - "status": e.code, "message": e.message - } - except NotRetryingDestination as e: - failures[destination] = { - "status": 503, "message": "Not ready for retry", - } except Exception as e: - # include ConnectionRefused and other errors - failures[destination] = { - "status": 503, "message": e.message - } + failures[destination] = _exception_to_failure(e) yield make_deferred_yieldable(defer.gatherResults([ - preserve_fn(do_remote_query)(destination) + run_in_background(do_remote_query, destination) for destination in remote_queries_not_in_cache - ])) + ], consumeErrors=True)) defer.returnValue({ "device_keys": results, "failures": failures, @@ -170,7 +163,8 @@ class E2eKeysHandler(object): result_dict = {} for user_id, device_ids in query.items(): - if not self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if not self.is_mine(UserID.from_string(user_id)): logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") @@ -213,7 +207,8 @@ class E2eKeysHandler(object): remote_queries = {} for user_id, device_keys in query.get("one_time_keys", {}).items(): - if self.is_mine_id(user_id): + # we use UserID.from_string to catch invalid user ids + if self.is_mine(UserID.from_string(user_id)): for device_id, algorithm in device_keys.items(): local_query.append((user_id, device_id, algorithm)) else: @@ -243,32 +238,21 @@ class E2eKeysHandler(object): for user_id, keys in remote_result["one_time_keys"].items(): if user_id in device_keys: json_result[user_id] = keys - except CodeMessageException as e: - failures[destination] = { - "status": e.code, "message": e.message - } - except NotRetryingDestination as e: - failures[destination] = { - "status": 503, "message": "Not ready for retry", - } except Exception as e: - # include ConnectionRefused and other errors - failures[destination] = { - "status": 503, "message": e.message - } + failures[destination] = _exception_to_failure(e) yield make_deferred_yieldable(defer.gatherResults([ - preserve_fn(claim_client_keys)(destination) + run_in_background(claim_client_keys, destination) for destination in remote_queries - ])) + ], consumeErrors=True)) logger.info( "Claimed one-time-keys: %s", ",".join(( "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in json_result.iteritems() - for device_id, device_keys in user_keys.iteritems() - for key_id, _ in device_keys.iteritems() + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) )), ) @@ -353,6 +337,31 @@ class E2eKeysHandler(object): ) +def _exception_to_failure(e): + if isinstance(e, CodeMessageException): + return { + "status": e.code, "message": e.message, + } + + if isinstance(e, NotRetryingDestination): + return { + "status": 503, "message": "Not ready for retry", + } + + if isinstance(e, FederationDeniedError): + return { + "status": 403, "message": "Federation Denied", + } + + # include ConnectionRefused and other errors + # + # Note that some Exceptions (notably twisted's ResponseFailed etc) don't + # give a string for e.message, which json then fails to serialize. + return { + "status": 503, "message": str(e.message), + } + + def _one_time_keys_match(old_key_json, new_key): old_key = json.loads(old_key_json) |