diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f4918d1bc6..e94e71bdad 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from collections import defaultdict
import six
from six import raise_from
@@ -45,6 +46,7 @@ from synapse.api.errors import (
)
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.logcontext import (
LoggingContext,
PreserveLoggingContext,
@@ -70,7 +72,10 @@ class VerifyKeyRequest(object):
json_object(dict): The JSON object to verify.
- deferred(Deferred[str, str, nacl.signing.VerifyKey]):
+ minimum_valid_until_ts (int): time at which we require the signing key to
+ be valid. (0 implies we don't care)
+
+ key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
logcontext.
@@ -82,7 +87,8 @@ class VerifyKeyRequest(object):
server_name = attr.ib()
key_ids = attr.ib()
json_object = attr.ib()
- deferred = attr.ib()
+ minimum_valid_until_ts = attr.ib()
+ key_ready = attr.ib(default=attr.Factory(defer.Deferred))
class KeyLookupError(ValueError):
@@ -90,14 +96,16 @@ class KeyLookupError(ValueError):
class Keyring(object):
- def __init__(self, hs):
+ def __init__(self, hs, key_fetchers=None):
self.clock = hs.get_clock()
- self._key_fetchers = (
- StoreKeyFetcher(hs),
- PerspectivesKeyFetcher(hs),
- ServerKeyFetcher(hs),
- )
+ if key_fetchers is None:
+ key_fetchers = (
+ StoreKeyFetcher(hs),
+ PerspectivesKeyFetcher(hs),
+ ServerKeyFetcher(hs),
+ )
+ self._key_fetchers = key_fetchers
# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
@@ -106,9 +114,25 @@ class Keyring(object):
# These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {}
- def verify_json_for_server(self, server_name, json_object):
+ def verify_json_for_server(self, server_name, json_object, validity_time):
+ """Verify that a JSON object has been signed by a given server
+
+ Args:
+ server_name (str): name of the server which must have signed this object
+
+ json_object (dict): object to be checked
+
+ validity_time (int): timestamp at which we require the signing key to
+ be valid. (0 implies we don't care)
+
+ Returns:
+ Deferred[None]: completes if the the object was correctly signed, otherwise
+ errbacks with an error
+ """
+ req = server_name, json_object, validity_time
+
return logcontext.make_deferred_yieldable(
- self.verify_json_objects_for_server([(server_name, json_object)])[0]
+ self.verify_json_objects_for_server((req,))[0]
)
def verify_json_objects_for_server(self, server_and_json):
@@ -116,10 +140,12 @@ class Keyring(object):
necessary.
Args:
- server_and_json (list): List of pairs of (server_name, json_object)
+ server_and_json (iterable[Tuple[str, dict, int]):
+ Iterable of triplets of (server_name, json_object, validity_time)
+ validity_time is a timestamp at which the signing key must be valid.
Returns:
- List<Deferred>: for each input pair, a deferred indicating success
+ List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
@@ -128,12 +154,12 @@ class Keyring(object):
verify_requests = []
handle = preserve_fn(_handle_key_deferred)
- def process(server_name, json_object):
+ def process(server_name, json_object, validity_time):
"""Process an entry in the request list
- Given a (server_name, json_object) pair from the request list,
- adds a key request to verify_requests, and returns a deferred which will
- complete or fail (in the sentinel context) when verification completes.
+ Given a (server_name, json_object, validity_time) triplet from the request
+ list, adds a key request to verify_requests, and returns a deferred which
+ will complete or fail (in the sentinel context) when verification completes.
"""
key_ids = signature_ids(json_object, server_name)
@@ -144,11 +170,16 @@ class Keyring(object):
)
)
- logger.debug("Verifying for %s with key_ids %s", server_name, key_ids)
+ logger.debug(
+ "Verifying for %s with key_ids %s, min_validity %i",
+ server_name,
+ key_ids,
+ validity_time,
+ )
# add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest(
- server_name, key_ids, json_object, defer.Deferred()
+ server_name, key_ids, json_object, validity_time
)
verify_requests.append(verify_request)
@@ -160,8 +191,8 @@ class Keyring(object):
return handle(verify_request)
results = [
- process(server_name, json_object)
- for server_name, json_object in server_and_json
+ process(server_name, json_object, validity_time)
+ for server_name, json_object, validity_time in server_and_json
]
if verify_requests:
@@ -173,7 +204,7 @@ class Keyring(object):
def _start_key_lookups(self, verify_requests):
"""Sets off the key fetches for each verify request
- Once each fetch completes, verify_request.deferred will be resolved.
+ Once each fetch completes, verify_request.key_ready will be resolved.
Args:
verify_requests (List[VerifyKeyRequest]):
@@ -219,7 +250,7 @@ class Keyring(object):
return res
for verify_request in verify_requests:
- verify_request.deferred.addBoth(remove_deferreds, verify_request)
+ verify_request.key_ready.addBoth(remove_deferreds, verify_request)
except Exception:
logger.exception("Error starting key lookups")
@@ -272,7 +303,7 @@ class Keyring(object):
def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request
- For each verify_request, verify_request.deferred is called back with
+ For each verify_request, verify_request.key_ready is called back with
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
with a SynapseError if none of the keys are found.
@@ -281,7 +312,7 @@ class Keyring(object):
"""
remaining_requests = set(
- (rq for rq in verify_requests if not rq.deferred.called)
+ (rq for rq in verify_requests if not rq.key_ready.called)
)
@defer.inlineCallbacks
@@ -295,11 +326,15 @@ class Keyring(object):
# look for any requests which weren't satisfied
with PreserveLoggingContext():
for verify_request in remaining_requests:
- verify_request.deferred.errback(
+ verify_request.key_ready.errback(
SynapseError(
401,
- "No key for %s with id %s"
- % (verify_request.server_name, verify_request.key_ids),
+ "No key for %s with ids in %s (min_validity %i)"
+ % (
+ verify_request.server_name,
+ verify_request.key_ids,
+ verify_request.minimum_valid_until_ts,
+ ),
Codes.UNAUTHORIZED,
)
)
@@ -311,8 +346,8 @@ class Keyring(object):
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
with PreserveLoggingContext():
for verify_request in remaining_requests:
- if not verify_request.deferred.called:
- verify_request.deferred.errback(err)
+ if not verify_request.key_ready.called:
+ verify_request.key_ready.errback(err)
run_in_background(do_iterations).addErrback(on_err)
@@ -323,18 +358,28 @@ class Keyring(object):
Args:
fetcher (KeyFetcher): fetcher to use to fetch the keys
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
- Any successfully-completed requests will be reomved from the list.
+ Any successfully-completed requests will be removed from the list.
"""
- # dict[str, set(str)]: keys to fetch for each server
- missing_keys = {}
+ # dict[str, dict[str, int]]: keys to fetch.
+ # server_name -> key_id -> min_valid_ts
+ missing_keys = defaultdict(dict)
+
for verify_request in remaining_requests:
# any completed requests should already have been removed
- assert not verify_request.deferred.called
- missing_keys.setdefault(verify_request.server_name, set()).update(
- verify_request.key_ids
- )
+ assert not verify_request.key_ready.called
+ keys_for_server = missing_keys[verify_request.server_name]
- results = yield fetcher.get_keys(missing_keys.items())
+ for key_id in verify_request.key_ids:
+ # If we have several requests for the same key, then we only need to
+ # request that key once, but we should do so with the greatest
+ # min_valid_until_ts of the requests, so that we can satisfy all of
+ # the requests.
+ keys_for_server[key_id] = max(
+ keys_for_server.get(key_id, -1),
+ verify_request.minimum_valid_until_ts
+ )
+
+ results = yield fetcher.get_keys(missing_keys)
completed = list()
for verify_request in remaining_requests:
@@ -344,25 +389,34 @@ class Keyring(object):
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
- key = result_keys.get(key_id)
- if key:
- with PreserveLoggingContext():
- verify_request.deferred.callback(
- (server_name, key_id, key.verify_key)
- )
- completed.append(verify_request)
- break
+ fetch_key_result = result_keys.get(key_id)
+ if not fetch_key_result:
+ # we didn't get a result for this key
+ continue
+
+ if (
+ fetch_key_result.valid_until_ts
+ < verify_request.minimum_valid_until_ts
+ ):
+ # key was not valid at this point
+ continue
+
+ with PreserveLoggingContext():
+ verify_request.key_ready.callback(
+ (server_name, key_id, fetch_key_result.verify_key)
+ )
+ completed.append(verify_request)
+ break
remaining_requests.difference_update(completed)
class KeyFetcher(object):
- def get_keys(self, server_name_and_key_ids):
+ def get_keys(self, keys_to_fetch):
"""
Args:
- server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
- list of (server_name, iterable[key_id]) tuples to fetch keys for
- Note that the iterables may be iterated more than once.
+ keys_to_fetch (dict[str, dict[str, int]]):
+ the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
@@ -378,13 +432,15 @@ class StoreKeyFetcher(KeyFetcher):
self.store = hs.get_datastore()
@defer.inlineCallbacks
- def get_keys(self, server_name_and_key_ids):
+ def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
+
keys_to_fetch = (
(server_name, key_id)
- for server_name, key_ids in server_name_and_key_ids
- for key_id in key_ids
+ for server_name, keys_for_server in keys_to_fetch.items()
+ for key_id in keys_for_server.keys()
)
+
res = yield self.store.get_server_verify_keys(keys_to_fetch)
keys = {}
for (server_name, key_id), key in res.items():
@@ -399,7 +455,7 @@ class BaseV2KeyFetcher(object):
@defer.inlineCallbacks
def process_v2_response(
- self, from_server, response_json, time_added_ms, requested_ids=[]
+ self, from_server, response_json, time_added_ms
):
"""Parse a 'Server Keys' structure from the result of a /key request
@@ -422,10 +478,6 @@ class BaseV2KeyFetcher(object):
time_added_ms (int): the timestamp to record in server_keys_json
- requested_ids (iterable[str]): a list of the key IDs that were requested.
- We will store the json for these key ids as well as any that are
- actually in the response
-
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
@@ -481,11 +533,6 @@ class BaseV2KeyFetcher(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json)
- # for reasons I don't quite understand, we store this json for the key ids we
- # requested, as well as those we got.
- updated_key_ids = set(requested_ids)
- updated_key_ids.update(verify_keys)
-
yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
@@ -498,7 +545,7 @@ class BaseV2KeyFetcher(object):
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
- for key_id in updated_key_ids
+ for key_id in verify_keys
],
consumeErrors=True,
).addErrback(unwrapFirstError)
@@ -517,14 +564,14 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.perspective_servers = self.config.perspectives
@defer.inlineCallbacks
- def get_keys(self, server_name_and_key_ids):
+ def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
- server_name_and_key_ids, perspective_name, perspective_keys
+ keys_to_fetch, perspective_name, perspective_keys
)
defer.returnValue(result)
except KeyLookupError as e:
@@ -558,13 +605,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(
- self, server_names_and_key_ids, perspective_name, perspective_keys
+ self, keys_to_fetch, perspective_name, perspective_keys
):
"""
Args:
- server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
- list of (server_name, iterable[key_id]) tuples to fetch keys for
+ keys_to_fetch (dict[str, dict[str, int]]):
+ the keys to be fetched. server_name -> key_id -> min_valid_ts
+
perspective_name (str): name of the notary server to query for the keys
+
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server
@@ -578,12 +627,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""
logger.info(
"Requesting keys %s from notary server %s",
- server_names_and_key_ids,
+ keys_to_fetch.items(),
perspective_name,
)
- # TODO(mark): Set the minimum_valid_until_ts to that needed by
- # the events being validated or the current time if validating
- # an incoming request.
+
try:
query_response = yield self.client.post_json(
destination=perspective_name,
@@ -591,12 +638,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
data={
u"server_keys": {
server_name: {
- key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
+ key_id: {u"minimum_valid_until_ts": min_valid_ts}
+ for key_id, min_valid_ts in server_keys.items()
}
- for server_name, key_ids in server_names_and_key_ids
+ for server_name, server_keys in keys_to_fetch.items()
}
},
- long_retries=True,
)
except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(KeyLookupError("Failed to connect to remote server"), e)
@@ -702,34 +749,54 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock()
self.client = hs.get_http_client()
- @defer.inlineCallbacks
- def get_keys(self, server_name_and_key_ids):
- """see KeyFetcher.get_keys"""
- results = yield logcontext.make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self.get_server_verify_key_v2_direct, server_name, key_ids
- )
- for server_name, key_ids in server_name_and_key_ids
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
- )
+ def get_keys(self, keys_to_fetch):
+ """
+ Args:
+ keys_to_fetch (dict[str, iterable[str]]):
+ the keys to be fetched. server_name -> key_ids
- merged = {}
- for result in results:
- merged.update(result)
+ Returns:
+ Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+ map from server_name -> key_id -> FetchKeyResult
+ """
+
+ results = {}
- defer.returnValue(
- {server_name: keys for server_name, keys in merged.items() if keys}
+ @defer.inlineCallbacks
+ def get_key(key_to_fetch_item):
+ server_name, key_ids = key_to_fetch_item
+ try:
+ keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+ results[server_name] = keys
+ except KeyLookupError as e:
+ logger.warning(
+ "Error looking up keys %s from %s: %s", key_ids, server_name, e
+ )
+ except Exception:
+ logger.exception("Error getting keys %s from %s", key_ids, server_name)
+
+ return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
+ lambda _: results
)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ """
+
+ Args:
+ server_name (str):
+ key_ids (iterable[str]):
+
+ Returns:
+ Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+
+ Raises:
+ KeyLookupError if there was a problem making the lookup
+ """
keys = {} # type: dict[str, FetchKeyResult]
for requested_key_id in key_ids:
+ # we may have found this key as a side-effect of asking for another.
if requested_key_id in keys:
continue
@@ -740,6 +807,19 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id),
ignore_backoff=True,
+
+ # we only give the remote server 10s to respond. It should be an
+ # easy request to handle, so if it doesn't reply within 10s, it's
+ # probably not going to.
+ #
+ # Furthermore, when we are acting as a notary server, we cannot
+ # wait all day for all of the origin servers, as the requesting
+ # server will otherwise time out before we can respond.
+ #
+ # (Note that get_json may make 4 attempts, so this can still take
+ # almost 45 seconds to fetch the headers, plus up to another 60s to
+ # read the response).
+ timeout=10000,
)
except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(KeyLookupError("Failed to connect to remote server"), e)
@@ -754,7 +834,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
response_keys = yield self.process_v2_response(
from_server=server_name,
- requested_ids=[requested_key_id],
response_json=response,
time_added_ms=time_now_ms,
)
@@ -765,7 +844,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
)
keys.update(response_keys)
- defer.returnValue({server_name: keys})
+ defer.returnValue(keys)
@defer.inlineCallbacks
@@ -783,14 +862,10 @@ def _handle_key_deferred(verify_request):
"""
server_name = verify_request.server_name
with PreserveLoggingContext():
- _, key_id, verify_key = yield verify_request.deferred
+ _, key_id, verify_key = yield verify_request.key_ready
json_object = verify_request.json_object
- logger.debug(
- "Got key %s %s:%s for server %s, verifying"
- % (key_id, verify_key.alg, verify_key.version, server_name)
- )
try:
verify_signed_json(json_object, server_name, verify_key)
except SignatureVerifyException as e:
|