diff options
author | Richard van der Hoff <1389908+richvdh@users.noreply.github.com> | 2019-05-22 18:39:33 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-05-22 18:39:33 +0100 |
commit | 1a94de60e89090b8930bcdefe51f47f204f52605 (patch) | |
tree | 3966c0d1f280ef5c0e663d8bab22c616aee5151b /synapse/crypto | |
parent | Merge branch 'master' into develop (diff) | |
download | synapse-1a94de60e89090b8930bcdefe51f47f204f52605.tar.xz |
Run black on synapse.crypto.keyring (#5232)
Diffstat (limited to 'synapse/crypto')
-rw-r--r-- | synapse/crypto/keyring.py | 286 |
1 files changed, 137 insertions, 149 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index d8ba870cca..5cc98542ce 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -56,9 +56,9 @@ from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) -VerifyKeyRequest = namedtuple("VerifyRequest", ( - "server_name", "key_ids", "json_object", "deferred" -)) +VerifyKeyRequest = namedtuple( + "VerifyRequest", ("server_name", "key_ids", "json_object", "deferred") +) """ A request for a verify key to verify a JSON object. @@ -96,9 +96,7 @@ class Keyring(object): def verify_json_for_server(self, server_name, json_object): return logcontext.make_deferred_yieldable( - self.verify_json_objects_for_server( - [(server_name, json_object)] - )[0] + self.verify_json_objects_for_server([(server_name, json_object)])[0] ) def verify_json_objects_for_server(self, server_and_json): @@ -130,18 +128,15 @@ class Keyring(object): if not key_ids: return defer.fail( SynapseError( - 400, - "Not signed by %s" % (server_name,), - Codes.UNAUTHORIZED, + 400, "Not signed by %s" % (server_name,), Codes.UNAUTHORIZED ) ) - logger.debug("Verifying for %s with key_ids %s", - server_name, key_ids) + logger.debug("Verifying for %s with key_ids %s", server_name, key_ids) # add the key request to the queue, but don't start it off yet. verify_request = VerifyKeyRequest( - server_name, key_ids, json_object, defer.Deferred(), + server_name, key_ids, json_object, defer.Deferred() ) verify_requests.append(verify_request) @@ -179,15 +174,13 @@ class Keyring(object): # any other lookups until we have finished. # The deferreds are called with no logcontext. server_to_deferred = { - rq.server_name: defer.Deferred() - for rq in verify_requests + rq.server_name: defer.Deferred() for rq in verify_requests } # We want to wait for any previous lookups to complete before # proceeding. yield self.wait_for_previous_lookups( - [rq.server_name for rq in verify_requests], - server_to_deferred, + [rq.server_name for rq in verify_requests], server_to_deferred ) # Actually start fetching keys. @@ -216,9 +209,7 @@ class Keyring(object): return res for verify_request in verify_requests: - verify_request.deferred.addBoth( - remove_deferreds, verify_request, - ) + verify_request.deferred.addBoth(remove_deferreds, verify_request) except Exception: logger.exception("Error starting key lookups") @@ -248,7 +239,8 @@ class Keyring(object): break logger.info( "Waiting for existing lookups for %s to complete [loop %i]", - [w[0] for w in wait_on], loop_count, + [w[0] for w in wait_on], + loop_count, ) with PreserveLoggingContext(): yield defer.DeferredList((w[1] for w in wait_on)) @@ -335,13 +327,14 @@ class Keyring(object): with PreserveLoggingContext(): for verify_request in requests_missing_keys: - verify_request.deferred.errback(SynapseError( - 401, - "No key for %s with id %s" % ( - verify_request.server_name, verify_request.key_ids, - ), - Codes.UNAUTHORIZED, - )) + verify_request.deferred.errback( + SynapseError( + 401, + "No key for %s with id %s" + % (verify_request.server_name, verify_request.key_ids), + Codes.UNAUTHORIZED, + ) + ) def on_err(err): with PreserveLoggingContext(): @@ -383,25 +376,26 @@ class Keyring(object): ) defer.returnValue(result) except KeyLookupError as e: - logger.warning( - "Key lookup failed from %r: %s", perspective_name, e, - ) + logger.warning("Key lookup failed from %r: %s", perspective_name, e) except Exception as e: logger.exception( "Unable to get key from %r: %s %s", perspective_name, - type(e).__name__, str(e), + type(e).__name__, + str(e), ) defer.returnValue({}) - results = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background(get_key, p_name, p_keys) - for p_name, p_keys in self.perspective_servers.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError)) + results = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(get_key, p_name, p_keys) + for p_name, p_keys in self.perspective_servers.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) union_of_keys = {} for result in results: @@ -412,32 +406,30 @@ class Keyring(object): @defer.inlineCallbacks def get_keys_from_server(self, server_name_and_key_ids): - results = yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.get_server_verify_key_v2_direct, - server_name, - key_ids, - ) - for server_name, key_ids in server_name_and_key_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError)) + results = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.get_server_verify_key_v2_direct, server_name, key_ids + ) + for server_name, key_ids in server_name_and_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) merged = {} for result in results: merged.update(result) - defer.returnValue({ - server_name: keys - for server_name, keys in merged.items() - if keys - }) + defer.returnValue( + {server_name: keys for server_name, keys in merged.items() if keys} + ) @defer.inlineCallbacks - def get_server_verify_key_v2_indirect(self, server_names_and_key_ids, - perspective_name, - perspective_keys): + def get_server_verify_key_v2_indirect( + self, server_names_and_key_ids, perspective_name, perspective_keys + ): # TODO(mark): Set the minimum_valid_until_ts to that needed by # the events being validated or the current time if validating # an incoming request. @@ -448,9 +440,7 @@ class Keyring(object): data={ u"server_keys": { server_name: { - key_id: { - u"minimum_valid_until_ts": 0 - } for key_id in key_ids + key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids } for server_name, key_ids in server_names_and_key_ids } @@ -458,21 +448,19 @@ class Keyring(object): long_retries=True, ) except (NotRetryingDestination, RequestSendFailed) as e: - raise_from( - KeyLookupError("Failed to connect to remote server"), e, - ) + raise_from(KeyLookupError("Failed to connect to remote server"), e) except HttpResponseException as e: - raise_from( - KeyLookupError("Remote server returned an error"), e, - ) + raise_from(KeyLookupError("Remote server returned an error"), e) keys = {} responses = query_response["server_keys"] for response in responses: - if (u"signatures" not in response - or perspective_name not in response[u"signatures"]): + if ( + u"signatures" not in response + or perspective_name not in response[u"signatures"] + ): raise KeyLookupError( "Key response not signed by perspective server" " %r" % (perspective_name,) @@ -482,9 +470,7 @@ class Keyring(object): for key_id in response[u"signatures"][perspective_name]: if key_id in perspective_keys: verify_signed_json( - response, - perspective_name, - perspective_keys[key_id] + response, perspective_name, perspective_keys[key_id] ) verified = True @@ -494,7 +480,7 @@ class Keyring(object): " known key, signed with: %r, known keys: %r", perspective_name, list(response[u"signatures"][perspective_name]), - list(perspective_keys) + list(perspective_keys), ) raise KeyLookupError( "Response not signed with a known key for perspective" @@ -508,18 +494,20 @@ class Keyring(object): keys.setdefault(server_name, {}).update(processed_response) - yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.store_keys, - server_name=server_name, - from_server=perspective_name, - verify_keys=response_keys, - ) - for server_name, response_keys in keys.items() - ], - consumeErrors=True - ).addErrback(unwrapFirstError)) + yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.store_keys, + server_name=server_name, + from_server=perspective_name, + verify_keys=response_keys, + ) + for server_name, response_keys in keys.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) defer.returnValue(keys) @@ -534,26 +522,26 @@ class Keyring(object): try: response = yield self.client.get_json( destination=server_name, - path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id), + path="/_matrix/key/v2/server/" + + urllib.parse.quote(requested_key_id), ignore_backoff=True, ) except (NotRetryingDestination, RequestSendFailed) as e: - raise_from( - KeyLookupError("Failed to connect to remote server"), e, - ) + raise_from(KeyLookupError("Failed to connect to remote server"), e) except HttpResponseException as e: - raise_from( - KeyLookupError("Remote server returned an error"), e, - ) + raise_from(KeyLookupError("Remote server returned an error"), e) - if (u"signatures" not in response - or server_name not in response[u"signatures"]): + if ( + u"signatures" not in response + or server_name not in response[u"signatures"] + ): raise KeyLookupError("Key response not signed by remote server") if response["server_name"] != server_name: - raise KeyLookupError("Expected a response for server %r not %r" % ( - server_name, response["server_name"] - )) + raise KeyLookupError( + "Expected a response for server %r not %r" + % (server_name, response["server_name"]) + ) response_keys = yield self.process_v2_response( from_server=server_name, @@ -564,16 +552,12 @@ class Keyring(object): keys.update(response_keys) yield self.store_keys( - server_name=server_name, - from_server=server_name, - verify_keys=keys, + server_name=server_name, from_server=server_name, verify_keys=keys ) defer.returnValue({server_name: keys}) @defer.inlineCallbacks - def process_v2_response( - self, from_server, response_json, requested_ids=[], - ): + def process_v2_response(self, from_server, response_json, requested_ids=[]): """Parse a 'Server Keys' structure from the result of a /key request This is used to parse either the entirety of the response from @@ -627,20 +611,13 @@ class Keyring(object): for key_id in response_json["signatures"].get(server_name, {}): if key_id not in response_json["verify_keys"]: raise KeyLookupError( - "Key response must include verification keys for all" - " signatures" + "Key response must include verification keys for all" " signatures" ) if key_id in verify_keys: - verify_signed_json( - response_json, - server_name, - verify_keys[key_id] - ) + verify_signed_json(response_json, server_name, verify_keys[key_id]) signed_key_json = sign_json( - response_json, - self.config.server_name, - self.config.signing_key[0], + response_json, self.config.server_name, self.config.signing_key[0] ) signed_key_json_bytes = encode_canonical_json(signed_key_json) @@ -653,21 +630,23 @@ class Keyring(object): response_keys.update(verify_keys) response_keys.update(old_verify_keys) - yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.store.store_server_keys_json, - server_name=server_name, - key_id=key_id, - from_server=from_server, - ts_now_ms=time_now_ms, - ts_expires_ms=ts_valid_until_ms, - key_json_bytes=signed_key_json_bytes, - ) - for key_id in updated_key_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError)) + yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.store.store_server_keys_json, + server_name=server_name, + key_id=key_id, + from_server=from_server, + ts_now_ms=time_now_ms, + ts_expires_ms=ts_valid_until_ms, + key_json_bytes=signed_key_json_bytes, + ) + for key_id in updated_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) defer.returnValue(response_keys) @@ -681,16 +660,21 @@ class Keyring(object): A deferred that completes when the keys are stored. """ # TODO(markjh): Store whether the keys have expired. - return logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.store.store_server_verify_key, - server_name, server_name, key.time_added, key - ) - for key_id, key in verify_keys.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError)) + return logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.store.store_server_verify_key, + server_name, + server_name, + key.time_added, + key, + ) + for key_id, key in verify_keys.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) @defer.inlineCallbacks @@ -713,17 +697,19 @@ def _handle_key_deferred(verify_request): except KeyLookupError as e: logger.warn( "Failed to download keys for %s: %s %s", - server_name, type(e).__name__, str(e), + server_name, + type(e).__name__, + str(e), ) raise SynapseError( - 502, - "Error downloading keys for %s" % (server_name,), - Codes.UNAUTHORIZED, + 502, "Error downloading keys for %s" % (server_name,), Codes.UNAUTHORIZED ) except Exception as e: logger.exception( "Got Exception when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e), + server_name, + type(e).__name__, + str(e), ) raise SynapseError( 401, @@ -733,22 +719,24 @@ def _handle_key_deferred(verify_request): json_object = verify_request.json_object - logger.debug("Got key %s %s:%s for server %s, verifying" % ( - key_id, verify_key.alg, verify_key.version, server_name, - )) + logger.debug( + "Got key %s %s:%s for server %s, verifying" + % (key_id, verify_key.alg, verify_key.version, server_name) + ) try: verify_signed_json(json_object, server_name, verify_key) except SignatureVerifyException as e: logger.debug( "Error verifying signature for %s:%s:%s with key %s: %s", - server_name, verify_key.alg, verify_key.version, + server_name, + verify_key.alg, + verify_key.version, encode_verify_key_base64(verify_key), str(e), ) raise SynapseError( 401, - "Invalid signature for server %s with key %s:%s: %s" % ( - server_name, verify_key.alg, verify_key.version, str(e), - ), + "Invalid signature for server %s with key %s:%s: %s" + % (server_name, verify_key.alg, verify_key.version, str(e)), Codes.UNAUTHORIZED, ) |