diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index aff69c5f83..873c9b40fa 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -25,11 +25,11 @@ from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
-
-from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError
from OpenSSL import crypto
+from collections import namedtuple
import urllib
import hashlib
import logging
@@ -38,6 +38,9 @@ import logging
logger = logging.getLogger(__name__)
+KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+
+
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -49,141 +52,274 @@ class Keyring(object):
self.key_downloads = {}
- @defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object):
- logger.debug("Verifying for %s", server_name)
- key_ids = signature_ids(json_object, server_name)
- if not key_ids:
- raise SynapseError(
- 400,
- "Not signed with a supported algorithm",
- Codes.UNAUTHORIZED,
- )
- try:
- verify_key = yield self.get_server_verify_key(server_name, key_ids)
- except IOError as e:
- logger.warn(
- "Got IOError when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
- )
- except Exception as e:
- logger.warn(
- "Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 401,
- "No key for %s with id %s" % (server_name, key_ids),
- Codes.UNAUTHORIZED,
- )
+ return self.verify_json_objects_for_server(
+ [(server_name, json_object)]
+ )[0]
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except:
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s" % (
- server_name, verify_key.alg, verify_key.version
- ),
- Codes.UNAUTHORIZED,
- )
+ def verify_json_objects_for_server(self, server_and_json):
+ """Bulk verfies signatures of json objects, bulk fetching keys as
+ necessary.
- @defer.inlineCallbacks
- def get_server_verify_key(self, server_name, key_ids):
- """Finds a verification key for the server with one of the key ids.
- Trys to fetch the key from a trusted perspective server first.
Args:
- server_name(str): The name of the server to fetch a key for.
- keys_ids (list of str): The key_ids to check for.
+ server_and_json (list): List of pairs of (server_name, json_object)
+
+ Returns:
+ list of deferreds indicating success or failure to verify each
+ json object's signature for the given server_name.
"""
- cached = yield self.store.get_server_verify_keys(server_name, key_ids)
+ group_id_to_json = {}
+ group_id_to_group = {}
+ group_ids = []
+
+ next_group_id = 0
+ deferreds = {}
+
+ for server_name, json_object in server_and_json:
+ logger.debug("Verifying for %s", server_name)
+ group_id = next_group_id
+ next_group_id += 1
+ group_ids.append(group_id)
+
+ key_ids = signature_ids(json_object, server_name)
+ if not key_ids:
+ deferreds[group_id] = defer.fail(SynapseError(
+ 400,
+ "Not signed with a supported algorithm",
+ Codes.UNAUTHORIZED,
+ ))
+
+ group = KeyGroup(server_name, group_id, key_ids)
- if cached:
- defer.returnValue(cached[0])
- return
+ group_id_to_group[group_id] = group
+ group_id_to_json[group_id] = json_object
- download = self.key_downloads.get(server_name)
+ @defer.inlineCallbacks
+ def handle_key_deferred(group, deferred):
+ server_name = group.server_name
+ try:
+ _, _, key_id, verify_key = yield deferred
+ except IOError as e:
+ logger.warn(
+ "Got IOError when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 502,
+ "Error downloading keys for %s" % (server_name,),
+ Codes.UNAUTHORIZED,
+ )
+ except Exception as e:
+ logger.exception(
+ "Got Exception when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 401,
+ "No key for %s with id %s" % (server_name, key_ids),
+ Codes.UNAUTHORIZED,
+ )
+
+ json_object = group_id_to_json[group.group_id]
- if download is None:
- download = self._get_server_verify_key_impl(server_name, key_ids)
- download = ObservableDeferred(
- download,
- consumeErrors=True
+ try:
+ verify_signed_json(json_object, server_name, verify_key)
+ except:
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s" % (
+ server_name, verify_key.alg, verify_key.version
+ ),
+ Codes.UNAUTHORIZED,
+ )
+
+ deferreds.update(self.get_server_verify_keys(
+ group_id_to_group
+ ))
+
+ return [
+ handle_key_deferred(
+ group_id_to_group[g_id],
+ deferreds[g_id],
)
- self.key_downloads[server_name] = download
+ for g_id in group_ids
+ ]
- @download.addBoth
- def callback(ret):
- del self.key_downloads[server_name]
- return ret
+ def get_server_verify_keys(self, group_id_to_group):
+ """Takes a dict of KeyGroups and tries to find at least one key for
+ each group.
+ """
+
+ # These are functions that produce keys given a list of key ids
+ key_fetch_fns = (
+ self.get_keys_from_store, # First try the local store
+ self.get_keys_from_perspectives, # Then try via perspectives
+ self.get_keys_from_server, # Then try directly
+ )
+
+ group_deferreds = {
+ group_id: defer.Deferred()
+ for group_id in group_id_to_group
+ }
+
+ @defer.inlineCallbacks
+ def do_iterations():
+ merged_results = {}
+
+ missing_keys = {
+ group.server_name: key_id
+ for group in group_id_to_group.values()
+ for key_id in group.key_ids
+ }
+
+ for fn in key_fetch_fns:
+ results = yield fn(missing_keys.items())
+ merged_results.update(results)
+
+ # We now need to figure out which groups we have keys for
+ # and which we don't
+ missing_groups = {}
+ for group in group_id_to_group.values():
+ for key_id in group.key_ids:
+ if key_id in merged_results[group.server_name]:
+ group_deferreds.pop(group.group_id).callback((
+ group.group_id,
+ group.server_name,
+ key_id,
+ merged_results[group.server_name][key_id],
+ ))
+ break
+ else:
+ missing_groups.setdefault(
+ group.server_name, []
+ ).append(group)
+
+ if not missing_groups:
+ break
+
+ missing_keys = {
+ server_name: set(
+ key_id for group in groups for key_id in group.key_ids
+ )
+ for server_name, groups in missing_groups.items()
+ }
- r = yield download.observe()
- defer.returnValue(r)
+ for group in missing_groups.values():
+ group_deferreds.pop(group.group_id).errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ group.server_name, group.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
+
+ def on_err(err):
+ for deferred in group_deferreds.values():
+ deferred.errback(err)
+ group_deferreds.clear()
+
+ do_iterations().addErrback(on_err)
+
+ return group_deferreds
@defer.inlineCallbacks
- def _get_server_verify_key_impl(self, server_name, key_ids):
- keys = None
+ def get_keys_from_store(self, server_name_and_key_ids):
+ res = yield defer.gatherResults(
+ [
+ self.store.get_server_verify_keys(server_name, key_ids)
+ for server_name, key_ids in server_name_and_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ defer.returnValue(dict(zip(
+ [server_name for server_name, _ in server_name_and_key_ids],
+ res
+ )))
+ @defer.inlineCallbacks
+ def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
- server_name, key_ids, perspective_name, perspective_keys
+ server_name_and_key_ids, perspective_name, perspective_keys
)
defer.returnValue(result)
except Exception as e:
- logging.info(
- "Unable to getting key %r for %r from %r: %s %s",
- key_ids, server_name, perspective_name,
+ logger.exception(
+ "Unable to get key from %r: %s %s",
+ perspective_name,
type(e).__name__, str(e.message),
)
+ defer.returnValue({})
- perspective_results = yield defer.gatherResults([
- get_key(p_name, p_keys)
- for p_name, p_keys in self.perspective_servers.items()
- ])
+ results = yield defer.gatherResults(
+ [
+ get_key(p_name, p_keys)
+ for p_name, p_keys in self.perspective_servers.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
- for results in perspective_results:
- if results is not None:
- keys = results
+ union_of_keys = {}
+ for result in results:
+ for server_name, keys in result.items():
+ union_of_keys.setdefault(server_name, {}).update(keys)
- limiter = yield get_retry_limiter(
- server_name,
- self.clock,
- self.store,
- )
+ defer.returnValue(union_of_keys)
- with limiter:
- if not keys:
+ @defer.inlineCallbacks
+ def get_keys_from_server(self, server_name_and_key_ids):
+ @defer.inlineCallbacks
+ def get_key(server_name, key_ids):
+ limiter = yield get_retry_limiter(
+ server_name,
+ self.clock,
+ self.store,
+ )
+ with limiter:
+ keys = None
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
- logging.info(
+ logger.info(
"Unable to getting key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
- if not keys:
- keys = yield self.get_server_verify_key_v1_direct(
- server_name, key_ids
- )
+ if not keys:
+ keys = yield self.get_server_verify_key_v1_direct(
+ server_name, key_ids
+ )
+
+ keys = {server_name: keys}
+
+ defer.returnValue(keys)
+
+ results = yield defer.gatherResults(
+ [
+ get_key(server_name, key_ids)
+ for server_name, key_ids in server_name_and_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
- for key_id in key_ids:
- if key_id in keys:
- defer.returnValue(keys[key_id])
- return
- raise ValueError("No verification key found for given key ids")
+ merged = {}
+ for result in results:
+ merged.update(result)
+
+ defer.returnValue({
+ server_name: keys
+ for server_name, keys in merged.items()
+ if keys
+ })
@defer.inlineCallbacks
- def get_server_verify_key_v2_indirect(self, server_name, key_ids,
+ def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name,
perspective_keys):
limiter = yield get_retry_limiter(
@@ -204,6 +340,7 @@ class Keyring(object):
u"minimum_valid_until_ts": 0
} for key_id in key_ids
}
+ for server_name, key_ids in server_names_and_key_ids
}
},
)
@@ -243,23 +380,29 @@ class Keyring(object):
" server %r" % (perspective_name,)
)
- response_keys = yield self.process_v2_response(
- server_name, perspective_name, response
+ processed_response = yield self.process_v2_response(
+ perspective_name, response
)
- keys.update(response_keys)
+ for server_name, response_keys in processed_response.items():
+ keys.setdefault(server_name, {}).update(response_keys)
- yield self.store_keys(
- server_name=server_name,
- from_server=perspective_name,
- verify_keys=keys,
- )
+ yield defer.gatherResults(
+ [
+ self.store_keys(
+ server_name=server_name,
+ from_server=perspective_name,
+ verify_keys=response_keys,
+ )
+ for server_name, response_keys in keys.items()
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
-
keys = {}
for requested_key_id in key_ids:
@@ -295,25 +438,30 @@ class Keyring(object):
raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
- server_name=server_name,
from_server=server_name,
- requested_id=requested_key_id,
+ requested_ids=[requested_key_id],
response_json=response,
)
keys.update(response_keys)
- yield self.store_keys(
- server_name=server_name,
- from_server=server_name,
- verify_keys=keys,
- )
+ yield defer.gatherResults(
+ [
+ self.store_keys(
+ server_name=key_server_name,
+ from_server=server_name,
+ verify_keys=verify_keys,
+ )
+ for key_server_name, verify_keys in keys.items()
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
defer.returnValue(keys)
@defer.inlineCallbacks
- def process_v2_response(self, server_name, from_server, response_json,
- requested_id=None):
+ def process_v2_response(self, from_server, response_json,
+ requested_ids=[]):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@@ -335,6 +483,8 @@ class Keyring(object):
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
+ results = {}
+ server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise ValueError(
@@ -357,28 +507,31 @@ class Keyring(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
- updated_key_ids = set()
- if requested_id is not None:
- updated_key_ids.add(requested_id)
+ updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
- for key_id in updated_key_ids:
- yield self.store.store_server_keys_json(
- server_name=server_name,
- key_id=key_id,
- from_server=server_name,
- ts_now_ms=time_now_ms,
- ts_expires_ms=ts_valid_until_ms,
- key_json_bytes=signed_key_json_bytes,
- )
+ yield defer.gatherResults(
+ [
+ self.store.store_server_keys_json(
+ server_name=server_name,
+ key_id=key_id,
+ from_server=server_name,
+ ts_now_ms=time_now_ms,
+ ts_expires_ms=ts_valid_until_ms,
+ key_json_bytes=signed_key_json_bytes,
+ )
+ for key_id in updated_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
- defer.returnValue(response_keys)
+ results[server_name] = response_keys
- raise ValueError("No verification key found for given key ids")
+ defer.returnValue(results)
@defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids):
@@ -462,8 +615,13 @@ class Keyring(object):
Returns:
A deferred that completes when the keys are stored.
"""
- for key_id, key in verify_keys.items():
- # TODO(markjh): Store whether the keys have expired.
- yield self.store.store_server_verify_key(
- server_name, server_name, key.time_added, key
- )
+ # TODO(markjh): Store whether the keys have expired.
+ yield defer.gatherResults(
+ [
+ self.store.store_server_verify_key(
+ server_name, server_name, key.time_added, key
+ )
+ for key_id, key in verify_keys.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
|