diff options
Diffstat (limited to 'synapse/crypto')
-rw-r--r-- | synapse/crypto/keyclient.py | 37 | ||||
-rw-r--r-- | synapse/crypto/keyring.py | 301 |
2 files changed, 314 insertions, 24 deletions
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 74008347c3..4911f0896b 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -25,12 +25,15 @@ import logging logger = logging.getLogger(__name__) +KEY_API_V1 = b"/_matrix/key/v1/" + @defer.inlineCallbacks -def fetch_server_key(server_name, ssl_context_factory): +def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1): """Fetch the keys for a remote server.""" factory = SynapseKeyClientFactory() + factory.path = path endpoint = matrix_federation_endpoint( reactor, server_name, ssl_context_factory, timeout=30 ) @@ -42,13 +45,19 @@ def fetch_server_key(server_name, ssl_context_factory): server_response, server_certificate = yield protocol.remote_key defer.returnValue((server_response, server_certificate)) return + except SynapseKeyClientError as e: + logger.exception("Error getting key for %r" % (server_name,)) + if e.status.startswith("4"): + # Don't retry for 4xx responses. + raise IOError("Cannot get key for %r" % server_name) except Exception as e: logger.exception(e) - raise IOError("Cannot get key for %s" % server_name) + raise IOError("Cannot get key for %r" % server_name) class SynapseKeyClientError(Exception): """The key wasn't retrieved from the remote server.""" + status = None pass @@ -66,17 +75,30 @@ class SynapseKeyClientProtocol(HTTPClient): def connectionMade(self): self.host = self.transport.getHost() logger.debug("Connected to %s", self.host) - self.sendCommand(b"GET", b"/_matrix/key/v1/") + self.sendCommand(b"GET", self.path) self.endHeaders() self.timer = reactor.callLater( self.timeout, self.on_timeout ) + def errback(self, error): + if not self.remote_key.called: + self.remote_key.errback(error) + + def callback(self, result): + if not self.remote_key.called: + self.remote_key.callback(result) + def handleStatus(self, version, status, message): if status != b"200": # logger.info("Non-200 response from %s: %s %s", # self.transport.getHost(), status, message) + error = SynapseKeyClientError( + "Non-200 response %r from %r" % (status, self.host) + ) + error.status = status + self.errback(error) self.transport.abortConnection() def handleResponse(self, response_body_bytes): @@ -89,15 +111,18 @@ class SynapseKeyClientProtocol(HTTPClient): return certificate = self.transport.getPeerCertificate() - self.remote_key.callback((json_response, certificate)) + self.callback((json_response, certificate)) self.transport.abortConnection() self.timer.cancel() def on_timeout(self): logger.debug("Timeout waiting for response from %s", self.host) - self.remote_key.errback(IOError("Timeout waiting for response")) + self.errback(IOError("Timeout waiting for response")) self.transport.abortConnection() class SynapseKeyClientFactory(Factory): - protocol = SynapseKeyClientProtocol + def protocol(self): + protocol = SynapseKeyClientProtocol() + protocol.path = self.path + return protocol diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 2b4faee4c1..8709394b97 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -15,7 +15,9 @@ from synapse.crypto.keyclient import fetch_server_key from twisted.internet import defer -from syutil.crypto.jsonsign import verify_signed_json, signature_ids +from syutil.crypto.jsonsign import ( + verify_signed_json, signature_ids, sign_json, encode_canonical_json +) from syutil.crypto.signing_key import ( is_signing_algorithm_supported, decode_verify_key_bytes ) @@ -28,6 +30,8 @@ from synapse.util.async import create_observer from OpenSSL import crypto +import urllib +import hashlib import logging @@ -38,6 +42,9 @@ class Keyring(object): def __init__(self, hs): self.store = hs.get_datastore() self.clock = hs.get_clock() + self.client = hs.get_http_client() + self.config = hs.get_config() + self.perspective_servers = self.config.perspectives self.hs = hs self.key_downloads = {} @@ -89,12 +96,11 @@ class Keyring(object): @defer.inlineCallbacks def get_server_verify_key(self, server_name, key_ids): """Finds a verification key for the server with one of the key ids. + Trys to fetch the key from a trusted perspective server first. Args: - server_name (str): The name of the server to fetch a key for. + server_name(str): The name of the server to fetch a key for. keys_ids (list of str): The key_ids to check for. """ - - # Check the datastore to see if we have one cached. cached = yield self.store.get_server_verify_keys(server_name, key_ids) if cached: @@ -117,7 +123,29 @@ class Keyring(object): @defer.inlineCallbacks def _get_server_verify_key_impl(self, server_name, key_ids): - # Try to fetch the key from the remote server. + keys = None + + perspective_results = [] + for perspective_name, perspective_keys in self.perspective_servers.items(): + @defer.inlineCallbacks + def get_key(): + try: + result = yield self.get_server_verify_key_v2_indirect( + server_name, key_ids, perspective_name, perspective_keys + ) + defer.returnValue(result) + except: + logging.info( + "Unable to getting key %r for %r from %r", + key_ids, server_name, perspective_name, + ) + perspective_results.append(get_key()) + + perspective_results = yield defer.gatherResults(perspective_results) + + for results in perspective_results: + if results is not None: + keys = results limiter = yield get_retry_limiter( server_name, @@ -126,10 +154,234 @@ class Keyring(object): ) with limiter: + if keys is None: + try: + keys = yield self.get_server_verify_key_v2_direct( + server_name, key_ids + ) + except: + pass + + keys = yield self.get_server_verify_key_v1_direct( + server_name, key_ids + ) + + for key_id in key_ids: + if key_id in keys: + defer.returnValue(keys[key_id]) + return + raise ValueError("No verification key found for given key ids") + + @defer.inlineCallbacks + def get_server_verify_key_v2_indirect(self, server_name, key_ids, + perspective_name, + perspective_keys): + limiter = yield get_retry_limiter( + perspective_name, self.clock, self.store + ) + + with limiter: + # TODO(mark): Set the minimum_valid_until_ts to that needed by + # the events being validated or the current time if validating + # an incoming request. + responses = yield self.client.post_json( + destination=perspective_name, + path=b"/_matrix/key/v2/query", + data={ + u"server_keys": { + server_name: { + key_id: { + u"minimum_valid_until_ts": 0 + } for key_id in key_ids + } + } + }, + ) + + keys = {} + + for response in responses: + if (u"signatures" not in response + or perspective_name not in response[u"signatures"]): + raise ValueError( + "Key response not signed by perspective server" + " %r" % (perspective_name,) + ) + + verified = False + for key_id in response[u"signatures"][perspective_name]: + if key_id in perspective_keys: + verify_signed_json( + response, + perspective_name, + perspective_keys[key_id] + ) + verified = True + + if not verified: + logging.info( + "Response from perspective server %r not signed with a" + " known key, signed with: %r, known keys: %r", + perspective_name, + list(response[u"signatures"][perspective_name]), + list(perspective_keys) + ) + raise ValueError( + "Response not signed with a known key for perspective" + " server %r" % (perspective_name,) + ) + + response_keys = yield self.process_v2_response( + server_name, perspective_name, response + ) + + keys.update(response_keys) + + yield self.store_keys( + server_name=server_name, + from_server=perspective_name, + verify_keys=keys, + ) + + defer.returnValue(keys) + + @defer.inlineCallbacks + def get_server_verify_key_v2_direct(self, server_name, key_ids): + + keys = {} + + for requested_key_id in key_ids: + if requested_key_id in keys: + continue + (response, tls_certificate) = yield fetch_server_key( - server_name, self.hs.tls_context_factory + server_name, self.hs.tls_context_factory, + path=(b"/_matrix/key/v2/server/%s" % ( + urllib.quote(requested_key_id), + )).encode("ascii"), + ) + + if (u"signatures" not in response + or server_name not in response[u"signatures"]): + raise ValueError("Key response not signed by remote server") + + if "tls_fingerprints" not in response: + raise ValueError("Key response missing TLS fingerprints") + + certificate_bytes = crypto.dump_certificate( + crypto.FILETYPE_ASN1, tls_certificate + ) + sha256_fingerprint = hashlib.sha256(certificate_bytes).digest() + sha256_fingerprint_b64 = encode_base64(sha256_fingerprint) + + response_sha256_fingerprints = set() + for fingerprint in response[u"tls_fingerprints"]: + if u"sha256" in fingerprint: + response_sha256_fingerprints.add(fingerprint[u"sha256"]) + + if sha256_fingerprint_b64 not in response_sha256_fingerprints: + raise ValueError("TLS certificate not allowed by fingerprints") + + response_keys = yield self.process_v2_response( + server_name=server_name, + from_server=server_name, + requested_id=requested_key_id, + response_json=response, + ) + + keys.update(response_keys) + + yield self.store_keys( + server_name=server_name, + from_server=server_name, + verify_keys=keys, + ) + + defer.returnValue(keys) + + @defer.inlineCallbacks + def process_v2_response(self, server_name, from_server, response_json, + requested_id=None): + time_now_ms = self.clock.time_msec() + response_keys = {} + verify_keys = {} + for key_id, key_data in response_json["verify_keys"].items(): + if is_signing_algorithm_supported(key_id): + key_base64 = key_data["key"] + key_bytes = decode_base64(key_base64) + verify_key = decode_verify_key_bytes(key_id, key_bytes) + verify_key.time_added = time_now_ms + verify_keys[key_id] = verify_key + + old_verify_keys = {} + for key_id, key_data in response_json["old_verify_keys"].items(): + if is_signing_algorithm_supported(key_id): + key_base64 = key_data["key"] + key_bytes = decode_base64(key_base64) + verify_key = decode_verify_key_bytes(key_id, key_bytes) + verify_key.expired = key_data["expired_ts"] + verify_key.time_added = time_now_ms + old_verify_keys[key_id] = verify_key + + for key_id in response_json["signatures"][server_name]: + if key_id not in response_json["verify_keys"]: + raise ValueError( + "Key response must include verification keys for all" + " signatures" + ) + if key_id in verify_keys: + verify_signed_json( + response_json, + server_name, + verify_keys[key_id] + ) + + signed_key_json = sign_json( + response_json, + self.config.server_name, + self.config.signing_key[0], + ) + + signed_key_json_bytes = encode_canonical_json(signed_key_json) + ts_valid_until_ms = signed_key_json[u"valid_until_ts"] + + updated_key_ids = set() + if requested_id is not None: + updated_key_ids.add(requested_id) + updated_key_ids.update(verify_keys) + updated_key_ids.update(old_verify_keys) + + response_keys.update(verify_keys) + response_keys.update(old_verify_keys) + + for key_id in updated_key_ids: + yield self.store.store_server_keys_json( + server_name=server_name, + key_id=key_id, + from_server=server_name, + ts_now_ms=time_now_ms, + ts_expires_ms=ts_valid_until_ms, + key_json_bytes=signed_key_json_bytes, ) + defer.returnValue(response_keys) + + raise ValueError("No verification key found for given key ids") + + @defer.inlineCallbacks + def get_server_verify_key_v1_direct(self, server_name, key_ids): + """Finds a verification key for the server with one of the key ids. + Args: + server_name (str): The name of the server to fetch a key for. + keys_ids (list of str): The key_ids to check for. + """ + + # Try to fetch the key from the remote server. + + (response, tls_certificate) = yield fetch_server_key( + server_name, self.hs.tls_context_factory + ) + # Check the response. x509_certificate_bytes = crypto.dump_certificate( @@ -148,11 +400,16 @@ class Keyring(object): if encode_base64(x509_certificate_bytes) != tls_certificate_b64: raise ValueError("TLS certificate doesn't match") + # Cache the result in the datastore. + + time_now_ms = self.clock.time_msec() + verify_keys = {} for key_id, key_base64 in response["verify_keys"].items(): if is_signing_algorithm_supported(key_id): key_bytes = decode_base64(key_base64) verify_key = decode_verify_key_bytes(key_id, key_bytes) + verify_key.time_added = time_now_ms verify_keys[key_id] = verify_key for key_id in response["signatures"][server_name]: @@ -168,10 +425,6 @@ class Keyring(object): verify_keys[key_id] ) - # Cache the result in the datastore. - - time_now_ms = self.clock.time_msec() - yield self.store.store_server_certificate( server_name, server_name, @@ -179,14 +432,26 @@ class Keyring(object): tls_certificate, ) + yield self.store_keys( + server_name=server_name, + from_server=server_name, + verify_keys=verify_keys, + ) + + defer.returnValue(verify_keys) + + @defer.inlineCallbacks + def store_keys(self, server_name, from_server, verify_keys): + """Store a collection of verify keys for a given server + Args: + server_name(str): The name of the server the keys are for. + from_server(str): The server the keys were downloaded from. + verify_keys(dict): A mapping of key_id to VerifyKey. + Returns: + A deferred that completes when the keys are stored. + """ for key_id, key in verify_keys.items(): + # TODO(markjh): Store whether the keys have expired. yield self.store.store_server_verify_key( - server_name, server_name, time_now_ms, key + server_name, server_name, key.time_added, key ) - - for key_id in key_ids: - if key_id in verify_keys: - defer.returnValue(verify_keys[key_id]) - return - - raise ValueError("No verification key found for given key ids") |