diff options
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r-- | synapse/crypto/keyring.py | 315 |
1 files changed, 175 insertions, 140 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 14a27288fd..eaf41b983c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -80,12 +80,13 @@ class KeyLookupError(ValueError): class Keyring(object): def __init__(self, hs): - self.store = hs.get_datastore() self.clock = hs.get_clock() - self.client = hs.get_http_client() - self.config = hs.get_config() - self.perspective_servers = self.config.perspectives - self.hs = hs + + self._key_fetchers = ( + StoreKeyFetcher(hs), + PerspectivesKeyFetcher(hs), + ServerKeyFetcher(hs), + ) # map from server name to Deferred. Has an entry for each server with # an ongoing key download; the Deferred completes once the download @@ -271,13 +272,6 @@ class Keyring(object): verify_requests (list[VerifyKeyRequest]): list of verify requests """ - # These are functions that produce keys given a list of key ids - key_fetch_fns = ( - self.get_keys_from_store, # First try the local store - self.get_keys_from_perspectives, # Then try via perspectives - self.get_keys_from_server, # Then try directly - ) - @defer.inlineCallbacks def do_iterations(): with Measure(self.clock, "get_server_verify_keys"): @@ -288,8 +282,8 @@ class Keyring(object): verify_request.key_ids ) - for fn in key_fetch_fns: - results = yield fn(missing_keys.items()) + for f in self._key_fetchers: + results = yield f.get_keys(missing_keys.items()) # We now need to figure out which verify requests we have keys # for and which we don't @@ -348,8 +342,9 @@ class Keyring(object): run_in_background(do_iterations).addErrback(on_err) - @defer.inlineCallbacks - def get_keys_from_store(self, server_name_and_key_ids): + +class KeyFetcher(object): + def get_keys(self, server_name_and_key_ids): """ Args: server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]): @@ -359,6 +354,18 @@ class Keyring(object): Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: map from server_name -> key_id -> FetchKeyResult """ + raise NotImplementedError + + +class StoreKeyFetcher(KeyFetcher): + """KeyFetcher impl which fetches keys from our data store""" + + def __init__(self, hs): + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def get_keys(self, server_name_and_key_ids): + """see KeyFetcher.get_keys""" keys_to_fetch = ( (server_name, key_id) for server_name, key_ids in server_name_and_key_ids @@ -370,8 +377,127 @@ class Keyring(object): keys.setdefault(server_name, {})[key_id] = key defer.returnValue(keys) + +class BaseV2KeyFetcher(object): + def __init__(self, hs): + self.store = hs.get_datastore() + self.config = hs.get_config() + + @defer.inlineCallbacks + def process_v2_response( + self, from_server, response_json, time_added_ms, requested_ids=[] + ): + """Parse a 'Server Keys' structure from the result of a /key request + + This is used to parse either the entirety of the response from + GET /_matrix/key/v2/server, or a single entry from the list returned by + POST /_matrix/key/v2/query. + + Checks that each signature in the response that claims to come from the origin + server is valid. (Does not check that there actually is such a signature, for + some reason.) + + Stores the json in server_keys_json so that it can be used for future responses + to /_matrix/key/v2/query. + + Args: + from_server (str): the name of the server producing this result: either + the origin server for a /_matrix/key/v2/server request, or the notary + for a /_matrix/key/v2/query. + + response_json (dict): the json-decoded Server Keys response object + + time_added_ms (int): the timestamp to record in server_keys_json + + requested_ids (iterable[str]): a list of the key IDs that were requested. + We will store the json for these key ids as well as any that are + actually in the response + + Returns: + Deferred[dict[str, FetchKeyResult]]: map from key_id to result object + """ + ts_valid_until_ms = response_json[u"valid_until_ts"] + + # start by extracting the keys from the response, since they may be required + # to validate the signature on the response. + verify_keys = {} + for key_id, key_data in response_json["verify_keys"].items(): + if is_signing_algorithm_supported(key_id): + key_base64 = key_data["key"] + key_bytes = decode_base64(key_base64) + verify_key = decode_verify_key_bytes(key_id, key_bytes) + verify_keys[key_id] = FetchKeyResult( + verify_key=verify_key, valid_until_ts=ts_valid_until_ms + ) + + # TODO: improve this signature checking + server_name = response_json["server_name"] + for key_id in response_json["signatures"].get(server_name, {}): + if key_id not in verify_keys: + raise KeyLookupError( + "Key response must include verification keys for all signatures" + ) + + verify_signed_json( + response_json, server_name, verify_keys[key_id].verify_key + ) + + for key_id, key_data in response_json["old_verify_keys"].items(): + if is_signing_algorithm_supported(key_id): + key_base64 = key_data["key"] + key_bytes = decode_base64(key_base64) + verify_key = decode_verify_key_bytes(key_id, key_bytes) + verify_keys[key_id] = FetchKeyResult( + verify_key=verify_key, valid_until_ts=key_data["expired_ts"] + ) + + # re-sign the json with our own key, so that it is ready if we are asked to + # give it out as a notary server + signed_key_json = sign_json( + response_json, self.config.server_name, self.config.signing_key[0] + ) + + signed_key_json_bytes = encode_canonical_json(signed_key_json) + + # for reasons I don't quite understand, we store this json for the key ids we + # requested, as well as those we got. + updated_key_ids = set(requested_ids) + updated_key_ids.update(verify_keys) + + yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.store.store_server_keys_json, + server_name=server_name, + key_id=key_id, + from_server=from_server, + ts_now_ms=time_added_ms, + ts_expires_ms=ts_valid_until_ms, + key_json_bytes=signed_key_json_bytes, + ) + for key_id in updated_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) + + defer.returnValue(verify_keys) + + +class PerspectivesKeyFetcher(BaseV2KeyFetcher): + """KeyFetcher impl which fetches keys from the "perspectives" servers""" + + def __init__(self, hs): + super(PerspectivesKeyFetcher, self).__init__(hs) + self.clock = hs.get_clock() + self.client = hs.get_http_client() + self.perspective_servers = self.config.perspectives + @defer.inlineCallbacks - def get_keys_from_perspectives(self, server_name_and_key_ids): + def get_keys(self, server_name_and_key_ids): + """see KeyFetcher.get_keys""" + @defer.inlineCallbacks def get_key(perspective_name, perspective_keys): try: @@ -409,28 +535,6 @@ class Keyring(object): defer.returnValue(union_of_keys) @defer.inlineCallbacks - def get_keys_from_server(self, server_name_and_key_ids): - results = yield logcontext.make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.get_server_verify_key_v2_direct, server_name, key_ids - ) - for server_name, key_ids in server_name_and_key_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - ) - - merged = {} - for result in results: - merged.update(result) - - defer.returnValue( - {server_name: keys for server_name, keys in merged.items() if keys} - ) - - @defer.inlineCallbacks def get_server_verify_key_v2_indirect( self, server_names_and_key_ids, perspective_name, perspective_keys ): @@ -520,6 +624,38 @@ class Keyring(object): defer.returnValue(keys) + +class ServerKeyFetcher(BaseV2KeyFetcher): + """KeyFetcher impl which fetches keys from the origin servers""" + + def __init__(self, hs): + super(ServerKeyFetcher, self).__init__(hs) + self.clock = hs.get_clock() + self.client = hs.get_http_client() + + @defer.inlineCallbacks + def get_keys(self, server_name_and_key_ids): + """see KeyFetcher.get_keys""" + results = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.get_server_verify_key_v2_direct, server_name, key_ids + ) + for server_name, key_ids in server_name_and_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) + + merged = {} + for result in results: + merged.update(result) + + defer.returnValue( + {server_name: keys for server_name, keys in merged.items() if keys} + ) + @defer.inlineCallbacks def get_server_verify_key_v2_direct(self, server_name, key_ids): keys = {} # type: dict[str, FetchKeyResult] @@ -568,107 +704,6 @@ class Keyring(object): defer.returnValue({server_name: keys}) - @defer.inlineCallbacks - def process_v2_response( - self, from_server, response_json, time_added_ms, requested_ids=[] - ): - """Parse a 'Server Keys' structure from the result of a /key request - - This is used to parse either the entirety of the response from - GET /_matrix/key/v2/server, or a single entry from the list returned by - POST /_matrix/key/v2/query. - - Checks that each signature in the response that claims to come from the origin - server is valid. (Does not check that there actually is such a signature, for - some reason.) - - Stores the json in server_keys_json so that it can be used for future responses - to /_matrix/key/v2/query. - - Args: - from_server (str): the name of the server producing this result: either - the origin server for a /_matrix/key/v2/server request, or the notary - for a /_matrix/key/v2/query. - - response_json (dict): the json-decoded Server Keys response object - - time_added_ms (int): the timestamp to record in server_keys_json - - requested_ids (iterable[str]): a list of the key IDs that were requested. - We will store the json for these key ids as well as any that are - actually in the response - - Returns: - Deferred[dict[str, FetchKeyResult]]: map from key_id to result object - """ - ts_valid_until_ms = response_json[u"valid_until_ts"] - - # start by extracting the keys from the response, since they may be required - # to validate the signature on the response. - verify_keys = {} - for key_id, key_data in response_json["verify_keys"].items(): - if is_signing_algorithm_supported(key_id): - key_base64 = key_data["key"] - key_bytes = decode_base64(key_base64) - verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_keys[key_id] = FetchKeyResult( - verify_key=verify_key, valid_until_ts=ts_valid_until_ms - ) - - # TODO: improve this signature checking - server_name = response_json["server_name"] - for key_id in response_json["signatures"].get(server_name, {}): - if key_id not in verify_keys: - raise KeyLookupError( - "Key response must include verification keys for all signatures" - ) - - verify_signed_json( - response_json, server_name, verify_keys[key_id].verify_key - ) - - for key_id, key_data in response_json["old_verify_keys"].items(): - if is_signing_algorithm_supported(key_id): - key_base64 = key_data["key"] - key_bytes = decode_base64(key_base64) - verify_key = decode_verify_key_bytes(key_id, key_bytes) - verify_keys[key_id] = FetchKeyResult( - verify_key=verify_key, valid_until_ts=key_data["expired_ts"] - ) - - # re-sign the json with our own key, so that it is ready if we are asked to - # give it out as a notary server - signed_key_json = sign_json( - response_json, self.config.server_name, self.config.signing_key[0] - ) - - signed_key_json_bytes = encode_canonical_json(signed_key_json) - - # for reasons I don't quite understand, we store this json for the key ids we - # requested, as well as those we got. - updated_key_ids = set(requested_ids) - updated_key_ids.update(verify_keys) - - yield logcontext.make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.store.store_server_keys_json, - server_name=server_name, - key_id=key_id, - from_server=from_server, - ts_now_ms=time_added_ms, - ts_expires_ms=ts_valid_until_ms, - key_json_bytes=signed_key_json_bytes, - ) - for key_id in updated_key_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - ) - - defer.returnValue(verify_keys) - @defer.inlineCallbacks def _handle_key_deferred(verify_request): |