From b5f55a1d85386834b18828b196e1859a4410d98a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 26 Jun 2015 09:52:24 +0100 Subject: Implement bulk verify_signed_json API --- synapse/crypto/keyring.py | 426 +++++++++++++++++++++++++++++++--------------- 1 file changed, 292 insertions(+), 134 deletions(-) (limited to 'synapse/crypto') diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index aff69c5f83..873c9b40fa 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -25,11 +25,11 @@ from syutil.base64util import decode_base64, encode_base64 from synapse.api.errors import SynapseError, Codes from synapse.util.retryutils import get_retry_limiter - -from synapse.util.async import ObservableDeferred +from synapse.util import unwrapFirstError from OpenSSL import crypto +from collections import namedtuple import urllib import hashlib import logging @@ -38,6 +38,9 @@ import logging logger = logging.getLogger(__name__) +KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids")) + + class Keyring(object): def __init__(self, hs): self.store = hs.get_datastore() @@ -49,141 +52,274 @@ class Keyring(object): self.key_downloads = {} - @defer.inlineCallbacks def verify_json_for_server(self, server_name, json_object): - logger.debug("Verifying for %s", server_name) - key_ids = signature_ids(json_object, server_name) - if not key_ids: - raise SynapseError( - 400, - "Not signed with a supported algorithm", - Codes.UNAUTHORIZED, - ) - try: - verify_key = yield self.get_server_verify_key(server_name, key_ids) - except IOError as e: - logger.warn( - "Got IOError when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), - ) - raise SynapseError( - 502, - "Error downloading keys for %s" % (server_name,), - Codes.UNAUTHORIZED, - ) - except Exception as e: - logger.warn( - "Got Exception when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), - ) - raise SynapseError( - 401, - "No key for %s with id %s" % (server_name, key_ids), - Codes.UNAUTHORIZED, - ) + return self.verify_json_objects_for_server( + [(server_name, json_object)] + )[0] - try: - verify_signed_json(json_object, server_name, verify_key) - except: - raise SynapseError( - 401, - "Invalid signature for server %s with key %s:%s" % ( - server_name, verify_key.alg, verify_key.version - ), - Codes.UNAUTHORIZED, - ) + def verify_json_objects_for_server(self, server_and_json): + """Bulk verfies signatures of json objects, bulk fetching keys as + necessary. - @defer.inlineCallbacks - def get_server_verify_key(self, server_name, key_ids): - """Finds a verification key for the server with one of the key ids. - Trys to fetch the key from a trusted perspective server first. Args: - server_name(str): The name of the server to fetch a key for. - keys_ids (list of str): The key_ids to check for. + server_and_json (list): List of pairs of (server_name, json_object) + + Returns: + list of deferreds indicating success or failure to verify each + json object's signature for the given server_name. 
""" - cached = yield self.store.get_server_verify_keys(server_name, key_ids) + group_id_to_json = {} + group_id_to_group = {} + group_ids = [] + + next_group_id = 0 + deferreds = {} + + for server_name, json_object in server_and_json: + logger.debug("Verifying for %s", server_name) + group_id = next_group_id + next_group_id += 1 + group_ids.append(group_id) + + key_ids = signature_ids(json_object, server_name) + if not key_ids: + deferreds[group_id] = defer.fail(SynapseError( + 400, + "Not signed with a supported algorithm", + Codes.UNAUTHORIZED, + )) + + group = KeyGroup(server_name, group_id, key_ids) - if cached: - defer.returnValue(cached[0]) - return + group_id_to_group[group_id] = group + group_id_to_json[group_id] = json_object - download = self.key_downloads.get(server_name) + @defer.inlineCallbacks + def handle_key_deferred(group, deferred): + server_name = group.server_name + try: + _, _, key_id, verify_key = yield deferred + except IOError as e: + logger.warn( + "Got IOError when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e.message), + ) + raise SynapseError( + 502, + "Error downloading keys for %s" % (server_name,), + Codes.UNAUTHORIZED, + ) + except Exception as e: + logger.exception( + "Got Exception when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e.message), + ) + raise SynapseError( + 401, + "No key for %s with id %s" % (server_name, key_ids), + Codes.UNAUTHORIZED, + ) + + json_object = group_id_to_json[group.group_id] - if download is None: - download = self._get_server_verify_key_impl(server_name, key_ids) - download = ObservableDeferred( - download, - consumeErrors=True + try: + verify_signed_json(json_object, server_name, verify_key) + except: + raise SynapseError( + 401, + "Invalid signature for server %s with key %s:%s" % ( + server_name, verify_key.alg, verify_key.version + ), + Codes.UNAUTHORIZED, + ) + + deferreds.update(self.get_server_verify_keys( + group_id_to_group + )) + + return [ + handle_key_deferred( + group_id_to_group[g_id], + deferreds[g_id], ) - self.key_downloads[server_name] = download + for g_id in group_ids + ] - @download.addBoth - def callback(ret): - del self.key_downloads[server_name] - return ret + def get_server_verify_keys(self, group_id_to_group): + """Takes a dict of KeyGroups and tries to find at least one key for + each group. 
+ """ + + # These are functions that produce keys given a list of key ids + key_fetch_fns = ( + self.get_keys_from_store, # First try the local store + self.get_keys_from_perspectives, # Then try via perspectives + self.get_keys_from_server, # Then try directly + ) + + group_deferreds = { + group_id: defer.Deferred() + for group_id in group_id_to_group + } + + @defer.inlineCallbacks + def do_iterations(): + merged_results = {} + + missing_keys = { + group.server_name: key_id + for group in group_id_to_group.values() + for key_id in group.key_ids + } + + for fn in key_fetch_fns: + results = yield fn(missing_keys.items()) + merged_results.update(results) + + # We now need to figure out which groups we have keys for + # and which we don't + missing_groups = {} + for group in group_id_to_group.values(): + for key_id in group.key_ids: + if key_id in merged_results[group.server_name]: + group_deferreds.pop(group.group_id).callback(( + group.group_id, + group.server_name, + key_id, + merged_results[group.server_name][key_id], + )) + break + else: + missing_groups.setdefault( + group.server_name, [] + ).append(group) + + if not missing_groups: + break + + missing_keys = { + server_name: set( + key_id for group in groups for key_id in group.key_ids + ) + for server_name, groups in missing_groups.items() + } - r = yield download.observe() - defer.returnValue(r) + for group in missing_groups.values(): + group_deferreds.pop(group.group_id).errback(SynapseError( + 401, + "No key for %s with id %s" % ( + group.server_name, group.key_ids, + ), + Codes.UNAUTHORIZED, + )) + + def on_err(err): + for deferred in group_deferreds.values(): + deferred.errback(err) + group_deferreds.clear() + + do_iterations().addErrback(on_err) + + return group_deferreds @defer.inlineCallbacks - def _get_server_verify_key_impl(self, server_name, key_ids): - keys = None + def get_keys_from_store(self, server_name_and_key_ids): + res = yield defer.gatherResults( + [ + self.store.get_server_verify_keys(server_name, key_ids) + for server_name, key_ids in server_name_and_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + defer.returnValue(dict(zip( + [server_name for server_name, _ in server_name_and_key_ids], + res + ))) + @defer.inlineCallbacks + def get_keys_from_perspectives(self, server_name_and_key_ids): @defer.inlineCallbacks def get_key(perspective_name, perspective_keys): try: result = yield self.get_server_verify_key_v2_indirect( - server_name, key_ids, perspective_name, perspective_keys + server_name_and_key_ids, perspective_name, perspective_keys ) defer.returnValue(result) except Exception as e: - logging.info( - "Unable to getting key %r for %r from %r: %s %s", - key_ids, server_name, perspective_name, + logger.exception( + "Unable to get key from %r: %s %s", + perspective_name, type(e).__name__, str(e.message), ) + defer.returnValue({}) - perspective_results = yield defer.gatherResults([ - get_key(p_name, p_keys) - for p_name, p_keys in self.perspective_servers.items() - ]) + results = yield defer.gatherResults( + [ + get_key(p_name, p_keys) + for p_name, p_keys in self.perspective_servers.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) - for results in perspective_results: - if results is not None: - keys = results + union_of_keys = {} + for result in results: + for server_name, keys in result.items(): + union_of_keys.setdefault(server_name, {}).update(keys) - limiter = yield get_retry_limiter( - server_name, - self.clock, - self.store, - ) + defer.returnValue(union_of_keys) - 
with limiter: - if not keys: + @defer.inlineCallbacks + def get_keys_from_server(self, server_name_and_key_ids): + @defer.inlineCallbacks + def get_key(server_name, key_ids): + limiter = yield get_retry_limiter( + server_name, + self.clock, + self.store, + ) + with limiter: + keys = None try: keys = yield self.get_server_verify_key_v2_direct( server_name, key_ids ) except Exception as e: - logging.info( + logger.info( "Unable to getting key %r for %r directly: %s %s", key_ids, server_name, type(e).__name__, str(e.message), ) - if not keys: - keys = yield self.get_server_verify_key_v1_direct( - server_name, key_ids - ) + if not keys: + keys = yield self.get_server_verify_key_v1_direct( + server_name, key_ids + ) + + keys = {server_name: keys} + + defer.returnValue(keys) + + results = yield defer.gatherResults( + [ + get_key(server_name, key_ids) + for server_name, key_ids in server_name_and_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) - for key_id in key_ids: - if key_id in keys: - defer.returnValue(keys[key_id]) - return - raise ValueError("No verification key found for given key ids") + merged = {} + for result in results: + merged.update(result) + + defer.returnValue({ + server_name: keys + for server_name, keys in merged.items() + if keys + }) @defer.inlineCallbacks - def get_server_verify_key_v2_indirect(self, server_name, key_ids, + def get_server_verify_key_v2_indirect(self, server_names_and_key_ids, perspective_name, perspective_keys): limiter = yield get_retry_limiter( @@ -204,6 +340,7 @@ class Keyring(object): u"minimum_valid_until_ts": 0 } for key_id in key_ids } + for server_name, key_ids in server_names_and_key_ids } }, ) @@ -243,23 +380,29 @@ class Keyring(object): " server %r" % (perspective_name,) ) - response_keys = yield self.process_v2_response( - server_name, perspective_name, response + processed_response = yield self.process_v2_response( + perspective_name, response ) - keys.update(response_keys) + for server_name, response_keys in processed_response.items(): + keys.setdefault(server_name, {}).update(response_keys) - yield self.store_keys( - server_name=server_name, - from_server=perspective_name, - verify_keys=keys, - ) + yield defer.gatherResults( + [ + self.store_keys( + server_name=server_name, + from_server=perspective_name, + verify_keys=response_keys, + ) + for server_name, response_keys in keys.items() + ], + consumeErrors=True + ).addErrback(unwrapFirstError) defer.returnValue(keys) @defer.inlineCallbacks def get_server_verify_key_v2_direct(self, server_name, key_ids): - keys = {} for requested_key_id in key_ids: @@ -295,25 +438,30 @@ class Keyring(object): raise ValueError("TLS certificate not allowed by fingerprints") response_keys = yield self.process_v2_response( - server_name=server_name, from_server=server_name, - requested_id=requested_key_id, + requested_ids=[requested_key_id], response_json=response, ) keys.update(response_keys) - yield self.store_keys( - server_name=server_name, - from_server=server_name, - verify_keys=keys, - ) + yield defer.gatherResults( + [ + self.store_keys( + server_name=key_server_name, + from_server=server_name, + verify_keys=verify_keys, + ) + for key_server_name, verify_keys in keys.items() + ], + consumeErrors=True + ).addErrback(unwrapFirstError) defer.returnValue(keys) @defer.inlineCallbacks - def process_v2_response(self, server_name, from_server, response_json, - requested_id=None): + def process_v2_response(self, from_server, response_json, + requested_ids=[]): time_now_ms = 
self.clock.time_msec() response_keys = {} verify_keys = {} @@ -335,6 +483,8 @@ class Keyring(object): verify_key.time_added = time_now_ms old_verify_keys[key_id] = verify_key + results = {} + server_name = response_json["server_name"] for key_id in response_json["signatures"].get(server_name, {}): if key_id not in response_json["verify_keys"]: raise ValueError( @@ -357,28 +507,31 @@ class Keyring(object): signed_key_json_bytes = encode_canonical_json(signed_key_json) ts_valid_until_ms = signed_key_json[u"valid_until_ts"] - updated_key_ids = set() - if requested_id is not None: - updated_key_ids.add(requested_id) + updated_key_ids = set(requested_ids) updated_key_ids.update(verify_keys) updated_key_ids.update(old_verify_keys) response_keys.update(verify_keys) response_keys.update(old_verify_keys) - for key_id in updated_key_ids: - yield self.store.store_server_keys_json( - server_name=server_name, - key_id=key_id, - from_server=server_name, - ts_now_ms=time_now_ms, - ts_expires_ms=ts_valid_until_ms, - key_json_bytes=signed_key_json_bytes, - ) + yield defer.gatherResults( + [ + self.store.store_server_keys_json( + server_name=server_name, + key_id=key_id, + from_server=server_name, + ts_now_ms=time_now_ms, + ts_expires_ms=ts_valid_until_ms, + key_json_bytes=signed_key_json_bytes, + ) + for key_id in updated_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) - defer.returnValue(response_keys) + results[server_name] = response_keys - raise ValueError("No verification key found for given key ids") + defer.returnValue(results) @defer.inlineCallbacks def get_server_verify_key_v1_direct(self, server_name, key_ids): @@ -462,8 +615,13 @@ class Keyring(object): Returns: A deferred that completes when the keys are stored. """ - for key_id, key in verify_keys.items(): - # TODO(markjh): Store whether the keys have expired. - yield self.store.store_server_verify_key( - server_name, server_name, key.time_added, key - ) + # TODO(markjh): Store whether the keys have expired. 
+ yield defer.gatherResults( + [ + self.store.store_server_verify_key( + server_name, server_name, key.time_added, key + ) + for key_id, key in verify_keys.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) -- cgit 1.5.1 From f0dd568e1681b878d9e85a12c21bfa0d5540ff73 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 26 Jun 2015 11:25:00 +0100 Subject: Wait for previous attempts at fetching keys for a given server before trying to fetch more --- synapse/crypto/keyring.py | 83 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 15 deletions(-) (limited to 'synapse/crypto') diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 873c9b40fa..aa74d4d0cb 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -27,6 +27,8 @@ from synapse.api.errors import SynapseError, Codes from synapse.util.retryutils import get_retry_limiter from synapse.util import unwrapFirstError +from synapse.util.async import ObservableDeferred + from OpenSSL import crypto from collections import namedtuple @@ -88,6 +90,8 @@ class Keyring(object): "Not signed with a supported algorithm", Codes.UNAUTHORIZED, )) + else: + deferreds[group_id] = defer.Deferred() group = KeyGroup(server_name, group_id, key_ids) @@ -133,10 +137,41 @@ class Keyring(object): Codes.UNAUTHORIZED, ) - deferreds.update(self.get_server_verify_keys( - group_id_to_group - )) + server_to_deferred = { + server_name: defer.Deferred() + for server_name, _ in server_and_json + } + + # We want to wait for any previous lookups to complete before + # proceeding. + wait_on_deferred = self.wait_for_previous_lookups( + [server_name for server_name, _ in server_and_json], + server_to_deferred, + ) + + # Actually start fetching keys. + wait_on_deferred.addBoth( + lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) + ) + + # When we've finished fetching all the keys for a given server_name, + # resolve the deferred passed to `wait_for_previous_lookups` so that + # any lookups waiting will proceed. + server_to_gids = {} + + def remove_deferreds(res, server_name, group_id): + server_to_gids[server_name].discard(group_id) + if not server_to_gids[server_name]: + server_to_deferred.pop(server_name).callback(None) + return res + for g_id, deferred in deferreds.items(): + server_name = group_id_to_group[g_id].server_name + server_to_gids.setdefault(server_name, set()).add(g_id) + deferred.addBoth(remove_deferreds, server_name, g_id) + + # Pass those keys to handle_key_deferred so that the json object + # signatures can be verified return [ handle_key_deferred( group_id_to_group[g_id], @@ -145,7 +180,30 @@ class Keyring(object): for g_id in group_ids ] - def get_server_verify_keys(self, group_id_to_group): + @defer.inlineCallbacks + def wait_for_previous_lookups(self, server_names, server_to_deferred): + """Waits for any previous key lookups for the given servers to finish. 
+ + Args: + server_names (list): list of server_names we want to lookup + server_to_deferred (dict): server_name to deferred which gets + resolved once we've finished looking up keys for that server + """ + while True: + wait_on = [ + self.key_downloads[server_name] + for server_name in server_names + if server_name in self.key_downloads + ] + if wait_on: + yield defer.DeferredList(wait_on) + else: + break + + for server_name, deferred in server_to_deferred: + self.key_downloads[server_name] = ObservableDeferred(deferred) + + def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred): """Takes a dict of KeyGroups and tries to find at least one key for each group. """ @@ -157,11 +215,6 @@ class Keyring(object): self.get_keys_from_server, # Then try directly ) - group_deferreds = { - group_id: defer.Deferred() - for group_id in group_id_to_group - } - @defer.inlineCallbacks def do_iterations(): merged_results = {} @@ -182,7 +235,7 @@ class Keyring(object): for group in group_id_to_group.values(): for key_id in group.key_ids: if key_id in merged_results[group.server_name]: - group_deferreds.pop(group.group_id).callback(( + group_id_to_deferred[group.group_id].callback(( group.group_id, group.server_name, key_id, @@ -205,7 +258,7 @@ class Keyring(object): } for group in missing_groups.values(): - group_deferreds.pop(group.group_id).errback(SynapseError( + group_id_to_deferred[group.group_id].errback(SynapseError( 401, "No key for %s with id %s" % ( group.server_name, group.key_ids, @@ -214,13 +267,13 @@ class Keyring(object): )) def on_err(err): - for deferred in group_deferreds.values(): - deferred.errback(err) - group_deferreds.clear() + for deferred in group_id_to_deferred.values(): + if not deferred.called: + deferred.errback(err) do_iterations().addErrback(on_err) - return group_deferreds + return group_id_to_deferred @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): -- cgit 1.5.1 From 64afbe6ccd19bb2ec94f3fbb3d91586202c924fd Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 8 Jul 2015 18:20:02 +0100 Subject: add new optional config for tls_certificate_chain_path for folks with intermediary SSL certs --- synapse/config/tls.py | 20 +++++++++++++++++--- synapse/crypto/context_factory.py | 2 ++ 2 files changed, 19 insertions(+), 3 deletions(-) (limited to 'synapse/crypto') diff --git a/synapse/config/tls.py b/synapse/config/tls.py index ecb2d42c1f..e04fe0d96c 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -25,9 +25,19 @@ GENERATE_DH_PARAMS = False class TlsConfig(Config): def read_config(self, config): self.tls_certificate = self.read_tls_certificate( - config.get("tls_certificate_path") + config.get("tls_certificate_path"), + "tls_certificate" ) + tls_certificate_chain_path = + config.get("tls_certificate_chain_path") + + if tls_certificate_chain_path and os.path.exists(tls_certificate_chain_path): + self.tls_certificate_chain = self.read_tls_certificate( + config.get("tls_certificate_chain_path"), + "tls_certificate_chain" + ) + self.no_tls = config.get("no_tls", False) if self.no_tls: @@ -45,6 +55,7 @@ class TlsConfig(Config): base_key_name = os.path.join(config_dir_path, server_name) tls_certificate_path = base_key_name + ".tls.crt" + tls_certificate_chain_path = base_key_name + ".tls.chain.crt" tls_private_key_path = base_key_name + ".tls.key" tls_dh_params_path = base_key_name + ".tls.dh" @@ -52,6 +63,9 @@ class TlsConfig(Config): # PEM encoded X509 certificate for TLS tls_certificate_path: 
"%(tls_certificate_path)s" + # PEM encoded X509 intermediary certificate file for TLS (optional) + # tls_certificate_chain_path: "%(tls_certificate_chain_path)s" + # PEM encoded private key for TLS tls_private_key_path: "%(tls_private_key_path)s" @@ -62,8 +76,8 @@ class TlsConfig(Config): no_tls: False """ % locals() - def read_tls_certificate(self, cert_path): - cert_pem = self.read_file(cert_path, "tls_certificate") + def read_tls_certificate(self, cert_path, config_name): + cert_pem = self.read_file(cert_path, config_name) return crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) def read_tls_private_key(self, private_key_path): diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 2f8618a0df..ea5dd1e7d3 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -38,6 +38,8 @@ class ServerContextFactory(ssl.ContextFactory): logger.exception("Failed to enable eliptic curve for TLS") context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.use_certificate(config.tls_certificate) + if config.tls_certificate_chain: + context.use_certificate_chain_file(config.tls_certificate_chain) if not config.no_tls: context.use_privatekey(config.tls_private_key) -- cgit 1.5.1 From 19fa3731aee498159cbc475f8b29f66dacb6aba6 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 8 Jul 2015 18:53:41 +0100 Subject: typo --- synapse/crypto/context_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/crypto') diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index ea5dd1e7d3..324dc31fe4 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -35,7 +35,7 @@ class ServerContextFactory(ssl.ContextFactory): _ecCurve = _OpenSSLECCurve(_defaultCurveName) _ecCurve.addECKeyToContext(context) except: - logger.exception("Failed to enable eliptic curve for TLS") + logger.exception("Failed to enable elliptic curve for TLS") context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.use_certificate(config.tls_certificate) if config.tls_certificate_chain: -- cgit 1.5.1 From f26a3df1bf6cabee28f5f91778082c0f26b2378c Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Wed, 8 Jul 2015 21:33:02 +0100 Subject: oops, context.tls_certificate_chain_file() expects a file, not a certificate. 
--- synapse/config/tls.py | 5 +---- synapse/crypto/context_factory.py | 4 ++-- 2 files changed, 3 insertions(+), 6 deletions(-) (limited to 'synapse/crypto') diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 945af38053..5fef63846d 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -32,10 +32,7 @@ class TlsConfig(Config): tls_certificate_chain_path = config.get("tls_certificate_chain_path") if tls_certificate_chain_path and os.path.exists(tls_certificate_chain_path): - self.tls_certificate_chain = self.read_tls_certificate( - config.get("tls_certificate_chain_path"), - "tls_certificate_chain" - ) + self.tls_certificate_chain_file = tls_certificate_chain_path else: self.tls_certificate_chain = None diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 324dc31fe4..d515007ca0 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -38,8 +38,8 @@ class ServerContextFactory(ssl.ContextFactory): logger.exception("Failed to enable elliptic curve for TLS") context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.use_certificate(config.tls_certificate) - if config.tls_certificate_chain: - context.use_certificate_chain_file(config.tls_certificate_chain) + if config.tls_certificate_chain_file: + context.use_certificate_chain_file(config.tls_certificate_chain_file) if not config.no_tls: context.use_privatekey(config.tls_private_key) -- cgit 1.5.1 From fb8d2862c1d7582096b5f8bd6194dcbe8e1afc01 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Thu, 9 Jul 2015 00:45:41 +0100 Subject: remove the tls_certificate_chain_path param and simply support tls_certificate_path pointing to a file containing a chain of certificates --- synapse/config/tls.py | 30 +++++++++--------------------- synapse/crypto/context_factory.py | 4 +--- 2 files changed, 10 insertions(+), 24 deletions(-) (limited to 'synapse/crypto') diff --git a/synapse/config/tls.py b/synapse/config/tls.py index de57d0d0ed..e136d13713 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -25,16 +25,9 @@ GENERATE_DH_PARAMS = False class TlsConfig(Config): def read_config(self, config): self.tls_certificate = self.read_tls_certificate( - config.get("tls_certificate_path"), - "tls_certificate" + config.get("tls_certificate_path") ) - - tls_certificate_chain_path = config.get("tls_certificate_chain_path") - - if tls_certificate_chain_path and os.path.exists(tls_certificate_chain_path): - self.tls_certificate_chain_file = tls_certificate_chain_path - else: - self.tls_certificate_chain = None + self.tls_certificate_file = config.get("tls_certificate_path"); self.no_tls = config.get("no_tls", False) @@ -53,22 +46,17 @@ class TlsConfig(Config): base_key_name = os.path.join(config_dir_path, server_name) tls_certificate_path = base_key_name + ".tls.crt" - tls_certificate_chain_path = base_key_name + ".tls.chain.crt" tls_private_key_path = base_key_name + ".tls.key" tls_dh_params_path = base_key_name + ".tls.dh" return """\ - # PEM encoded X509 certificate for TLS + # PEM encoded X509 certificate for TLS. + # You can replace the self-signed certificate that synapse + # autogenerates on launch with your own SSL certificate + key pair + # if you like. Any required intermediary certificates can be + # appended after the primary certificate in hierarchical order. 
tls_certificate_path: "%(tls_certificate_path)s" - # PEM encoded X509 intermediary certificate file for TLS (optional) - # This *must* be a concatenation of the tls_certificate pointed to - # by tls_certificate_path followed by the intermediary certificates - # in hierarchical order. If specified this option overrides the - # tls_certificate_path parameter. - # - # tls_certificate_chain_path: "%(tls_certificate_chain_path)s" - # PEM encoded private key for TLS tls_private_key_path: "%(tls_private_key_path)s" @@ -79,8 +67,8 @@ class TlsConfig(Config): no_tls: False """ % locals() - def read_tls_certificate(self, cert_path, config_name): - cert_pem = self.read_file(cert_path, config_name) + def read_tls_certificate(self, cert_path): + cert_pem = self.read_file(cert_path, "tls_certificate") return crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem) def read_tls_private_key(self, private_key_path): diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index d515007ca0..c4390f3b2b 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -37,9 +37,7 @@ class ServerContextFactory(ssl.ContextFactory): except: logger.exception("Failed to enable elliptic curve for TLS") context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) - context.use_certificate(config.tls_certificate) - if config.tls_certificate_chain_file: - context.use_certificate_chain_file(config.tls_certificate_chain_file) + context.use_certificate_chain_file(config.tls_certificate_file) if not config.no_tls: context.use_privatekey(config.tls_private_key) -- cgit 1.5.1
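
The patches above replace the old per-request `get_server_verify_key` path with a bulk API: a caller hands `verify_json_objects_for_server` a list of (server_name, json_object) pairs and gets back one deferred per pair, while the keyring batches the underlying key fetches (local store first, then perspective servers, then the origin server directly) and waits for any in-flight lookups for the same server before starting new ones. The sketch below shows how a caller might drive the new API; it is illustrative only — the `keyring` instance and the `signed_events` list are assumed to be supplied by surrounding homeserver code and are not part of these diffs.

from twisted.internet import defer


@defer.inlineCallbacks
def verify_incoming(keyring, signed_events):
    """Illustrative sketch: verify a batch of signed JSON objects in one go.

    `keyring` is assumed to be the homeserver's Keyring instance;
    `signed_events` is a hypothetical list of (origin, json_object) pairs.
    """
    # One deferred per (server_name, json_object) pair, in the same order
    # as the input list.
    deferreds = keyring.verify_json_objects_for_server(signed_events)

    # DeferredList with consumeErrors=True collects a per-object outcome
    # instead of letting one bad signature fail the whole batch.
    results = yield defer.DeferredList(deferreds, consumeErrors=True)

    verified, failed = [], []
    for (origin, obj), (success, value) in zip(signed_events, results):
        if success:
            verified.append((origin, obj))
        else:
            # On failure, `value` is a Failure wrapping the SynapseError
            # raised by handle_key_deferred.
            failed.append((origin, value))

    defer.returnValue((verified, failed))

Using DeferredList here mirrors the design of the bulk API itself: each json object succeeds or fails independently, so a single unreachable or badly-signed origin does not block verification of the rest of the batch.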