diff options
Diffstat (limited to 'synapse/crypto')
-rw-r--r-- | synapse/crypto/keyclient.py | 17 | ||||
-rw-r--r-- | synapse/crypto/keyring.py | 64 |
2 files changed, 49 insertions, 32 deletions
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 4911f0896b..24f15f3154 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient from twisted.internet.protocol import Factory from twisted.internet import defer, reactor from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + preserve_context_over_fn, preserve_context_over_deferred +) import simplejson as json import logging @@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1): for i in range(5): try: - with PreserveLoggingContext(): - protocol = yield endpoint.connect(factory) - server_response, server_certificate = yield protocol.remote_key - defer.returnValue((server_response, server_certificate)) - return + protocol = yield preserve_context_over_fn( + endpoint.connect, factory + ) + server_response, server_certificate = yield preserve_context_over_deferred( + protocol.remote_key + ) + defer.returnValue((server_response, server_certificate)) + return except SynapseKeyClientError as e: logger.exception("Error getting key for %r" % (server_name,)) if e.status.startswith("4"): diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 8709394b97..aff69c5f83 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes from synapse.util.retryutils import get_retry_limiter -from synapse.util.async import create_observer +from synapse.util.async import ObservableDeferred from OpenSSL import crypto @@ -111,6 +111,10 @@ class Keyring(object): if download is None: download = self._get_server_verify_key_impl(server_name, key_ids) + download = ObservableDeferred( + download, + consumeErrors=True + ) self.key_downloads[server_name] = download @download.addBoth @@ -118,30 +122,31 @@ class Keyring(object): del self.key_downloads[server_name] return ret - r = yield create_observer(download) + r = yield download.observe() defer.returnValue(r) @defer.inlineCallbacks def _get_server_verify_key_impl(self, server_name, key_ids): keys = None - perspective_results = [] - for perspective_name, perspective_keys in self.perspective_servers.items(): - @defer.inlineCallbacks - def get_key(): - try: - result = yield self.get_server_verify_key_v2_indirect( - server_name, key_ids, perspective_name, perspective_keys - ) - defer.returnValue(result) - except: - logging.info( - "Unable to getting key %r for %r from %r", - key_ids, server_name, perspective_name, - ) - perspective_results.append(get_key()) + @defer.inlineCallbacks + def get_key(perspective_name, perspective_keys): + try: + result = yield self.get_server_verify_key_v2_indirect( + server_name, key_ids, perspective_name, perspective_keys + ) + defer.returnValue(result) + except Exception as e: + logging.info( + "Unable to getting key %r for %r from %r: %s %s", + key_ids, server_name, perspective_name, + type(e).__name__, str(e.message), + ) - perspective_results = yield defer.gatherResults(perspective_results) + perspective_results = yield defer.gatherResults([ + get_key(p_name, p_keys) + for p_name, p_keys in self.perspective_servers.items() + ]) for results in perspective_results: if results is not None: @@ -154,17 +159,22 @@ class Keyring(object): ) with limiter: - if keys is None: + if not keys: try: keys = yield self.get_server_verify_key_v2_direct( server_name, key_ids ) - except: - pass + except Exception as e: + logging.info( + "Unable to getting key %r for %r directly: %s %s", + key_ids, server_name, + type(e).__name__, str(e.message), + ) - keys = yield self.get_server_verify_key_v1_direct( - server_name, key_ids - ) + if not keys: + keys = yield self.get_server_verify_key_v1_direct( + server_name, key_ids + ) for key_id in key_ids: if key_id in keys: @@ -184,7 +194,7 @@ class Keyring(object): # TODO(mark): Set the minimum_valid_until_ts to that needed by # the events being validated or the current time if validating # an incoming request. - responses = yield self.client.post_json( + query_response = yield self.client.post_json( destination=perspective_name, path=b"/_matrix/key/v2/query", data={ @@ -200,6 +210,8 @@ class Keyring(object): keys = {} + responses = query_response["server_keys"] + for response in responses: if (u"signatures" not in response or perspective_name not in response[u"signatures"]): @@ -323,7 +335,7 @@ class Keyring(object): verify_key.time_added = time_now_ms old_verify_keys[key_id] = verify_key - for key_id in response_json["signatures"][server_name]: + for key_id in response_json["signatures"].get(server_name, {}): if key_id not in response_json["verify_keys"]: raise ValueError( "Key response must include verification keys for all" |