diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index aff69c5f83..eb94cd5b75 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -25,11 +25,11 @@ from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
-
-from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError
from OpenSSL import crypto
+from collections import namedtuple
import urllib
import hashlib
import logging
@@ -38,6 +38,9 @@ import logging
logger = logging.getLogger(__name__)
+KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+
+
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -49,141 +52,257 @@ class Keyring(object):
self.key_downloads = {}
- @defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object):
- logger.debug("Verifying for %s", server_name)
- key_ids = signature_ids(json_object, server_name)
- if not key_ids:
- raise SynapseError(
- 400,
- "Not signed with a supported algorithm",
- Codes.UNAUTHORIZED,
- )
- try:
- verify_key = yield self.get_server_verify_key(server_name, key_ids)
- except IOError as e:
- logger.warn(
- "Got IOError when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
- )
- except Exception as e:
- logger.warn(
- "Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 401,
- "No key for %s with id %s" % (server_name, key_ids),
- Codes.UNAUTHORIZED,
- )
+ return self.verify_json_objects_for_server(
+ [(server_name, json_object)]
+ )[0]
+
+ def verify_json_objects_for_server(self, server_and_json):
+ server_to_key_groupings = {}
+ group_id_to_json = {}
+ group_id_to_group = {}
+ group_ids = []
+
+ next_group_id = 0
+
+ for server_name, json_object in server_and_json:
+ logger.debug("Verifying for %s", server_name)
+ key_ids = signature_ids(json_object, server_name)
+ if not key_ids:
+ raise SynapseError(
+ 400,
+ "Not signed with a supported algorithm",
+ Codes.UNAUTHORIZED,
+ )
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except:
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s" % (
- server_name, verify_key.alg, verify_key.version
- ),
- Codes.UNAUTHORIZED,
- )
+ group_id = next_group_id
+ next_group_id += 1
+ group_ids.append(group_id)
- @defer.inlineCallbacks
- def get_server_verify_key(self, server_name, key_ids):
- """Finds a verification key for the server with one of the key ids.
- Trys to fetch the key from a trusted perspective server first.
- Args:
- server_name(str): The name of the server to fetch a key for.
- keys_ids (list of str): The key_ids to check for.
- """
- cached = yield self.store.get_server_verify_keys(server_name, key_ids)
+ group = KeyGroup(server_name, group_id, key_ids)
+
+ group_id_to_group[group_id] = group
+ group_id_to_json[group_id] = json_object
+ server_to_key_groupings.setdefault(server_name, []).append(group)
+
+ @defer.inlineCallbacks
+ def handle_key_deferred(group, deferred):
+ server_name = group.server_name
+ try:
+ _, _, key_id, verify_key = yield deferred
+ except IOError as e:
+ logger.warn(
+ "Got IOError when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 502,
+ "Error downloading keys for %s" % (server_name,),
+ Codes.UNAUTHORIZED,
+ )
+ except Exception as e:
+ logger.exception(
+ "Got Exception when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 401,
+ "No key for %s with id %s" % (server_name, key_ids),
+ Codes.UNAUTHORIZED,
+ )
+
+ json_object = group_id_to_json[group.group_id]
+
+ try:
+ verify_signed_json(json_object, server_name, verify_key)
+ except:
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s" % (
+ server_name, verify_key.alg, verify_key.version
+ ),
+ Codes.UNAUTHORIZED,
+ )
- if cached:
- defer.returnValue(cached[0])
- return
+ deferreds = self.get_server_verify_keys(
+ group_id_to_group
+ )
- download = self.key_downloads.get(server_name)
+ logger.info(
+ "Deferred count: %d vs. %d",
+ len(deferreds.items()),
+ len(server_and_json)
+ )
- if download is None:
- download = self._get_server_verify_key_impl(server_name, key_ids)
- download = ObservableDeferred(
- download,
- consumeErrors=True
+ return [
+ handle_key_deferred(
+ group_id_to_group[g_id],
+ deferreds[g_id],
)
- self.key_downloads[server_name] = download
+ for g_id in group_ids
+ ]
+
+ def get_server_verify_keys(self, group_id_to_group):
+ merged_results = {}
- @download.addBoth
- def callback(ret):
- del self.key_downloads[server_name]
- return ret
+ fns = (
+ self.get_keys_from_store, # First try the local store
+ self.get_keys_from_perspectives, # Then try via perspectives
+ self.get_keys_from_server, # Then try directly
+ )
- r = yield download.observe()
- defer.returnValue(r)
+ group_deferreds = {
+ group_id: defer.Deferred()
+ for group_id in group_id_to_group
+ }
+
+ @defer.inlineCallbacks
+ def do_iterations():
+ missing_keys = {
+ group.server_name: key_id
+ for group in group_id_to_group.values()
+ for key_id in group.key_ids
+ }
+
+ for fn in fns:
+ results = yield fn(missing_keys.items())
+ merged_results.update(results)
+
+ missing_groups = {}
+ for group in group_id_to_group.values():
+ for key_id in group.key_ids:
+ if key_id in merged_results[group.server_name]:
+ group_deferreds.pop(group.group_id).callback((
+ group.group_id,
+ group.server_name,
+ key_id,
+ merged_results[group.server_name][key_id],
+ ))
+ break
+ else:
+ missing_groups.setdefault(
+ group.server_name, []
+ ).append(group)
+
+ if not missing_groups:
+ break
+
+ missing_keys = {
+ server_name: set(
+ key_id for group in groups for key_id in group.key_ids
+ )
+ for server_name, groups in missing_groups.items()
+ }
+
+ for group in missing_groups.values():
+ group_deferreds.pop(group.group_id).errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ group.server_name, group.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
+
+ def on_err(err):
+ for deferred in group_deferreds.values():
+ deferred.errback(err)
+ group_deferreds.clear()
+
+ do_iterations().addErrback(on_err)
+
+ return group_deferreds
@defer.inlineCallbacks
- def _get_server_verify_key_impl(self, server_name, key_ids):
- keys = None
+ def get_keys_from_store(self, server_name_and_key_ids):
+ res = yield defer.gatherResults(
+ [
+ self.store.get_server_verify_keys(server_name, key_ids)
+ for server_name, key_ids in server_name_and_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ defer.returnValue(dict(zip(
+ [server_name for server_name, _ in server_name_and_key_ids],
+ res
+ )))
+ @defer.inlineCallbacks
+ def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
- server_name, key_ids, perspective_name, perspective_keys
+ server_name_and_key_ids, perspective_name, perspective_keys
)
defer.returnValue(result)
except Exception as e:
- logging.info(
- "Unable to getting key %r for %r from %r: %s %s",
- key_ids, server_name, perspective_name,
+ logger.info(
+ "Unable to get key from %r: %s %s",
+ perspective_name,
type(e).__name__, str(e.message),
)
- perspective_results = yield defer.gatherResults([
+ results = yield defer.gatherResults([
get_key(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
])
- for results in perspective_results:
- if results is not None:
- keys = results
+ union_of_keys = {}
+ for result in results:
+ for server_name, keys in results.items():
+ union_of_keys.setdefault(server_name, {}).update(keys)
- limiter = yield get_retry_limiter(
- server_name,
- self.clock,
- self.store,
- )
+ defer.returnValue(union_of_keys)
- with limiter:
- if not keys:
+ @defer.inlineCallbacks
+ def get_keys_from_server(self, server_name_and_key_ids):
+ @defer.inlineCallbacks
+ def get_key(server_name, key_ids):
+ limiter = yield get_retry_limiter(
+ server_name,
+ self.clock,
+ self.store,
+ )
+ with limiter:
+ keys = None
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
- logging.info(
+ logger.info(
"Unable to getting key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
- if not keys:
- keys = yield self.get_server_verify_key_v1_direct(
- server_name, key_ids
- )
+ if not keys:
+ keys = yield self.get_server_verify_key_v1_direct(
+ server_name, key_ids
+ )
+
+ keys = {server_name: keys}
+
+ defer.returnValue(keys)
+
+ results = yield defer.gatherResults([
+ get_key(server_name, key_ids)
+ for server_name, key_ids in server_name_and_key_ids
+ ])
+
+ merged = {}
+ for result in results:
+ merged.update(result)
- for key_id in key_ids:
- if key_id in keys:
- defer.returnValue(keys[key_id])
- return
- raise ValueError("No verification key found for given key ids")
+ defer.returnValue({
+ server_name: keys
+ for server_name, keys in merged.items()
+ if keys
+ })
@defer.inlineCallbacks
- def get_server_verify_key_v2_indirect(self, server_name, key_ids,
+ def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name,
perspective_keys):
limiter = yield get_retry_limiter(
@@ -204,6 +323,7 @@ class Keyring(object):
u"minimum_valid_until_ts": 0
} for key_id in key_ids
}
+ for server_name, key_ids in server_names_and_key_ids
}
},
)
@@ -243,23 +363,24 @@ class Keyring(object):
" server %r" % (perspective_name,)
)
- response_keys = yield self.process_v2_response(
- server_name, perspective_name, response
+ processed_response = yield self.process_v2_response(
+ perspective_name, response
)
- keys.update(response_keys)
+ for server_name, response_keys in processed_response:
+ keys.setdefault(server_name, {}).update(response_keys)
- yield self.store_keys(
- server_name=server_name,
- from_server=perspective_name,
- verify_keys=keys,
- )
+ for server_name, response_keys in keys.items():
+ yield self.store_keys(
+ server_name=server_name,
+ from_server=perspective_name,
+ verify_keys=keys,
+ )
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
-
keys = {}
for requested_key_id in key_ids:
@@ -295,25 +416,25 @@ class Keyring(object):
raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
- server_name=server_name,
from_server=server_name,
- requested_id=requested_key_id,
+ requested_ids=[requested_key_id],
response_json=response,
)
keys.update(response_keys)
- yield self.store_keys(
- server_name=server_name,
- from_server=server_name,
- verify_keys=keys,
- )
+ for server_name, verify_keys in keys.items():
+ yield self.store_keys(
+ server_name=server_name,
+ from_server=server_name,
+ verify_keys=verify_keys,
+ )
defer.returnValue(keys)
@defer.inlineCallbacks
- def process_v2_response(self, server_name, from_server, response_json,
- requested_id=None):
+ def process_v2_response(self, from_server, response_json,
+ requested_ids=[]):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@@ -335,50 +456,50 @@ class Keyring(object):
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
- for key_id in response_json["signatures"].get(server_name, {}):
- if key_id not in response_json["verify_keys"]:
- raise ValueError(
- "Key response must include verification keys for all"
- " signatures"
- )
- if key_id in verify_keys:
- verify_signed_json(
- response_json,
- server_name,
- verify_keys[key_id]
- )
+ results = {}
+ for server_name, keys_dict in response_json["signatures"].items():
+ for key_id in keys_dict:
+ if key_id not in response_json["verify_keys"]:
+ raise ValueError(
+ "Key response must include verification keys for all"
+ " signatures"
+ )
+ if key_id in verify_keys:
+ verify_signed_json(
+ response_json,
+ server_name,
+ verify_keys[key_id]
+ )
- signed_key_json = sign_json(
- response_json,
- self.config.server_name,
- self.config.signing_key[0],
- )
+ signed_key_json = sign_json(
+ response_json,
+ self.config.server_name,
+ self.config.signing_key[0],
+ )
- signed_key_json_bytes = encode_canonical_json(signed_key_json)
- ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
+ signed_key_json_bytes = encode_canonical_json(signed_key_json)
+ ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
- updated_key_ids = set()
- if requested_id is not None:
- updated_key_ids.add(requested_id)
- updated_key_ids.update(verify_keys)
- updated_key_ids.update(old_verify_keys)
+ updated_key_ids = set(requested_ids)
+ updated_key_ids.update(verify_keys)
+ updated_key_ids.update(old_verify_keys)
- response_keys.update(verify_keys)
- response_keys.update(old_verify_keys)
+ response_keys.update(verify_keys)
+ response_keys.update(old_verify_keys)
- for key_id in updated_key_ids:
- yield self.store.store_server_keys_json(
- server_name=server_name,
- key_id=key_id,
- from_server=server_name,
- ts_now_ms=time_now_ms,
- ts_expires_ms=ts_valid_until_ms,
- key_json_bytes=signed_key_json_bytes,
- )
+ for key_id in updated_key_ids:
+ yield self.store.store_server_keys_json(
+ server_name=server_name,
+ key_id=key_id,
+ from_server=server_name,
+ ts_now_ms=time_now_ms,
+ ts_expires_ms=ts_valid_until_ms,
+ key_json_bytes=signed_key_json_bytes,
+ )
- defer.returnValue(response_keys)
+ results[server_name] = response_keys
- raise ValueError("No verification key found for given key ids")
+ defer.returnValue(results)
@defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids):
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 299493af91..407e0f815c 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -99,35 +99,50 @@ class FederationBase(object):
defer.returnValue(signed_pdus)
- @defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
- """Throws a SynapseError if the PDU does not have the correct
+ return self._check_sigs_and_hashes([pdu])[0]
+
+ def _check_sigs_and_hashes(self, pdus):
+ """Throws a SynapseError if a PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
- # Check signatures are correct.
- redacted_event = prune_event(pdu)
- redacted_pdu_json = redacted_event.get_pdu_json()
- try:
- yield self.keyring.verify_json_for_server(
- pdu.origin, redacted_pdu_json
- )
- except SynapseError:
+ redacted_pdus = [
+ prune_event(pdu)
+ for pdu in pdus
+ ]
+
+ deferreds = self.keyring.verify_json_objects_for_server([
+ (p.origin, p.get_pdu_json())
+ for p in redacted_pdus
+ ])
+
+ def callback(_, pdu, redacted):
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+ return pdu
+
+ def errback(failure, pdu):
+ failure.trap(SynapseError)
logger.warn(
"Signature check failed for %s",
pdu.event_id,
)
- raise
+ return failure
- if not check_event_content_hash(pdu):
- logger.warn(
- "Event content has been tampered, redacting.",
- pdu.event_id,
+ for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
+ deferred.addCallbacks(
+ callback, errback,
+ callbackArgs=[pdu, redacted],
+ errbackArgs=[pdu],
)
- defer.returnValue(redacted_event)
- defer.returnValue(pdu)
+ return deferreds
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 7ee3c66bf2..47d71542e4 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -166,10 +166,7 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield defer.gatherResults(
- [self._check_sigs_and_hash(pdu) for pdu in pdus],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ pdus[:] = yield self._check_sigs_and_hashes(pdus)
defer.returnValue(pdus)
@@ -230,7 +227,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- pdu = yield self._check_sigs_and_hash(pdu)
+ pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
@@ -402,7 +399,7 @@ class FederationClient(FederationBase):
except CodeMessageException:
raise
except Exception as e:
- logger.warn(
+ logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 2902e35181..4f990b7792 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -101,7 +101,11 @@ class KeyStore(SQLBaseStore):
(list of VerifyKey): The verification keys.
"""
keys = yield self.get_all_server_verify_keys(server_name)
- defer.returnValue([keys[k] for k in key_ids if k in keys])
+ defer.returnValue({
+ k: keys[k]
+ for k in key_ids
+ if k in keys and keys[k]
+ })
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
|