diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d7211ee9b3..22ee0fc93f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,12 +16,11 @@
from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes
-from synapse.util.retryutils import get_retry_limiter
-from synapse.util import unwrapFirstError
-from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import (
- preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
- preserve_fn
+ PreserveLoggingContext,
+ preserve_fn,
+ run_in_background,
)
from synapse.util.metrics import Measure
@@ -58,7 +58,8 @@ Attributes:
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when
- a verify key has been fetched
+ a verify key has been fetched. The deferreds' callbacks are run with no
+ logcontext.
"""
@@ -75,31 +76,41 @@ class Keyring(object):
self.perspective_servers = self.config.perspectives
self.hs = hs
+ # map from server name to Deferred. Has an entry for each server with
+ # an ongoing key download; the Deferred completes once the download
+ # completes.
+ #
+ # These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object):
- return self.verify_json_objects_for_server(
- [(server_name, json_object)]
- )[0]
+ return logcontext.make_deferred_yieldable(
+ self.verify_json_objects_for_server(
+ [(server_name, json_object)]
+ )[0]
+ )
def verify_json_objects_for_server(self, server_and_json):
- """Bulk verfies signatures of json objects, bulk fetching keys as
+ """Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
server_and_json (list): List of pairs of (server_name, json_object)
Returns:
- list of deferreds indicating success or failure to verify each
- json object's signature for the given server_name.
+ List<Deferred>: for each input pair, a deferred indicating success
+ or failure to verify each json object's signature for the given
+ server_name. The deferreds run their callbacks in the sentinel
+ logcontext.
"""
verify_requests = []
for server_name, json_object in server_and_json:
- logger.debug("Verifying for %s", server_name)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
+ logger.warn("Request from %s: no supported signature keys",
+ server_name)
deferred = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
@@ -108,76 +119,69 @@ class Keyring(object):
else:
deferred = defer.Deferred()
+ logger.debug("Verifying for %s with key_ids %s",
+ server_name, key_ids)
+
verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, deferred
)
verify_requests.append(verify_request)
- @defer.inlineCallbacks
- def handle_key_deferred(verify_request):
- server_name = verify_request.server_name
- try:
- _, key_id, verify_key = yield verify_request.deferred
- except IOError as e:
- logger.warn(
- "Got IOError when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
- )
- except Exception as e:
- logger.exception(
- "Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 401,
- "No key for %s with id %s" % (server_name, key_ids),
- Codes.UNAUTHORIZED,
- )
+ run_in_background(self._start_key_lookups, verify_requests)
- json_object = verify_request.json_object
+ # Pass those keys to handle_key_deferred so that the json object
+ # signatures can be verified
+ handle = preserve_fn(_handle_key_deferred)
+ return [
+ handle(rq) for rq in verify_requests
+ ]
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except:
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s" % (
- server_name, verify_key.alg, verify_key.version
- ),
- Codes.UNAUTHORIZED,
- )
+ @defer.inlineCallbacks
+ def _start_key_lookups(self, verify_requests):
+ """Sets off the key fetches for each verify request
- server_to_deferred = {
- server_name: defer.Deferred()
- for server_name, _ in server_and_json
- }
+ Once each fetch completes, verify_request.deferred will be resolved.
- with PreserveLoggingContext():
+ Args:
+ verify_requests (List[VerifyKeyRequest]):
+ """
+
+ try:
+ # create a deferred for each server we're going to look up the keys
+ # for; we'll resolve them once we have completed our lookups.
+ # These will be passed into wait_for_previous_lookups to block
+ # any other lookups until we have finished.
+ # The deferreds are called with no logcontext.
+ server_to_deferred = {
+ rq.server_name: defer.Deferred()
+ for rq in verify_requests
+ }
# We want to wait for any previous lookups to complete before
# proceeding.
- wait_on_deferred = self.wait_for_previous_lookups(
- [server_name for server_name, _ in server_and_json],
+ yield self.wait_for_previous_lookups(
+ [rq.server_name for rq in verify_requests],
server_to_deferred,
)
# Actually start fetching keys.
- wait_on_deferred.addBoth(
- lambda _: self.get_server_verify_keys(verify_requests)
- )
+ self._get_server_verify_keys(verify_requests)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
+ #
+ # map from server name to a set of request ids
server_to_request_ids = {}
- def remove_deferreds(res, server_name, verify_request):
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids.setdefault(server_name, set()).add(request_id)
+
+ def remove_deferreds(res, verify_request):
+ server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
@@ -187,17 +191,11 @@ class Keyring(object):
return res
for verify_request in verify_requests:
- server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids.setdefault(server_name, set()).add(request_id)
- deferred.addBoth(remove_deferreds, server_name, verify_request)
-
- # Pass those keys to handle_key_deferred so that the json object
- # signatures can be verified
- return [
- preserve_context_over_fn(handle_key_deferred, verify_request)
- for verify_request in verify_requests
- ]
+ verify_request.deferred.addBoth(
+ remove_deferreds, verify_request,
+ )
+ except Exception:
+ logger.exception("Error starting key lookups")
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred):
@@ -206,7 +204,13 @@ class Keyring(object):
Args:
server_names (list): list of server_names we want to lookup
server_to_deferred (dict): server_name to deferred which gets
- resolved once we've finished looking up keys for that server
+ resolved once we've finished looking up keys for that server.
+ The Deferreds should be regular twisted ones which call their
+ callbacks with no logcontext.
+
+ Returns: a Deferred which resolves once all key lookups for the given
+ servers have completed. Follows the synapse rules of logcontext
+ preservation.
"""
while True:
wait_on = [
@@ -220,19 +224,23 @@ class Keyring(object):
else:
break
+ def rm(r, server_name_):
+ self.key_downloads.pop(server_name_, None)
+ return r
+
for server_name, deferred in server_to_deferred.items():
- d = ObservableDeferred(preserve_context_over_deferred(deferred))
- self.key_downloads[server_name] = d
+ self.key_downloads[server_name] = deferred
+ deferred.addBoth(rm, server_name)
- def rm(r, server_name):
- self.key_downloads.pop(server_name, None)
- return r
+ def _get_server_verify_keys(self, verify_requests):
+ """Tries to find at least one key for each verify request
- d.addBoth(rm, server_name)
+ For each verify_request, verify_request.deferred is called back with
+ params (server_name, key_id, VerifyKey) if a key is found, or errbacked
+ with a SynapseError if none of the keys are found.
- def get_server_verify_keys(self, verify_requests):
- """Takes a dict of KeyGroups and tries to find at least one key for
- each group.
+ Args:
+ verify_requests (list[VerifyKeyRequest]): list of verify requests
"""
# These are functions that produce keys given a list of key ids
@@ -245,8 +253,11 @@ class Keyring(object):
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
+ # dict[str, dict[str, VerifyKey]]: results so far.
+ # map server_name -> key_id -> VerifyKey
merged_results = {}
+ # dict[str, set(str)]: keys to fetch for each server
missing_keys = {}
for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update(
@@ -290,33 +301,46 @@ class Keyring(object):
if not missing_keys:
break
- for verify_request in requests_missing_keys.values():
- verify_request.deferred.errback(SynapseError(
- 401,
- "No key for %s with id %s" % (
- verify_request.server_name, verify_request.key_ids,
- ),
- Codes.UNAUTHORIZED,
- ))
+ with PreserveLoggingContext():
+ for verify_request in requests_missing_keys:
+ verify_request.deferred.errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ verify_request.server_name, verify_request.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
def on_err(err):
- for verify_request in verify_requests:
- if not verify_request.deferred.called:
- verify_request.deferred.errback(err)
+ with PreserveLoggingContext():
+ for verify_request in verify_requests:
+ if not verify_request.deferred.called:
+ verify_request.deferred.errback(err)
- do_iterations().addErrback(on_err)
+ run_in_background(do_iterations).addErrback(on_err)
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
- res = yield preserve_context_over_deferred(defer.gatherResults(
+ """
+
+ Args:
+ server_name_and_key_ids (list[(str, iterable[str])]):
+ list of (server_name, iterable[key_id]) tuples to fetch keys for
+
+ Returns:
+ Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
+ server_name -> key_id -> VerifyKey
+ """
+ res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.get_server_verify_keys)(
- server_name, key_ids
+ run_in_background(
+ self.store.get_server_verify_keys,
+ server_name, key_ids,
).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(dict(res))
@@ -333,17 +357,17 @@ class Keyring(object):
logger.exception(
"Unable to get key from %r: %s %s",
perspective_name,
- type(e).__name__, str(e.message),
+ type(e).__name__, str(e),
)
defer.returnValue({})
- results = yield preserve_context_over_deferred(defer.gatherResults(
+ results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(get_key)(p_name, p_keys)
+ run_in_background(get_key, p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
union_of_keys = {}
for result in results:
@@ -356,40 +380,34 @@ class Keyring(object):
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
- limiter = yield get_retry_limiter(
- server_name,
- self.clock,
- self.store,
- )
- with limiter:
- keys = None
- try:
- keys = yield self.get_server_verify_key_v2_direct(
- server_name, key_ids
- )
- except Exception as e:
- logger.info(
- "Unable to get key %r for %r directly: %s %s",
- key_ids, server_name,
- type(e).__name__, str(e.message),
- )
+ keys = None
+ try:
+ keys = yield self.get_server_verify_key_v2_direct(
+ server_name, key_ids
+ )
+ except Exception as e:
+ logger.info(
+ "Unable to get key %r for %r directly: %s %s",
+ key_ids, server_name,
+ type(e).__name__, str(e),
+ )
- if not keys:
- keys = yield self.get_server_verify_key_v1_direct(
- server_name, key_ids
- )
+ if not keys:
+ keys = yield self.get_server_verify_key_v1_direct(
+ server_name, key_ids
+ )
- keys = {server_name: keys}
+ keys = {server_name: keys}
defer.returnValue(keys)
- results = yield preserve_context_over_deferred(defer.gatherResults(
+ results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(get_key)(server_name, key_ids)
+ run_in_background(get_key, server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
merged = {}
for result in results:
@@ -466,9 +484,10 @@ class Keyring(object):
for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
- yield preserve_context_over_deferred(defer.gatherResults(
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store_keys)(
+ run_in_background(
+ self.store_keys,
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
@@ -476,7 +495,7 @@ class Keyring(object):
for server_name, response_keys in keys.items()
],
consumeErrors=True
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(keys)
@@ -524,9 +543,10 @@ class Keyring(object):
keys.update(response_keys)
- yield preserve_context_over_deferred(defer.gatherResults(
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store_keys)(
+ run_in_background(
+ self.store_keys,
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
@@ -534,7 +554,7 @@ class Keyring(object):
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(keys)
@@ -600,9 +620,10 @@ class Keyring(object):
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
- yield preserve_context_over_deferred(defer.gatherResults(
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.store_server_keys_json)(
+ run_in_background(
+ self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=server_name,
@@ -613,7 +634,7 @@ class Keyring(object):
for key_id in updated_key_ids
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
results[server_name] = response_keys
@@ -691,7 +712,6 @@ class Keyring(object):
defer.returnValue(verify_keys)
- @defer.inlineCallbacks
def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
@@ -702,12 +722,57 @@ class Keyring(object):
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
- yield preserve_context_over_deferred(defer.gatherResults(
+ return logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.store_server_verify_key)(
+ run_in_background(
+ self.store.store_server_verify_key,
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
+
+
+@defer.inlineCallbacks
+def _handle_key_deferred(verify_request):
+ server_name = verify_request.server_name
+ try:
+ with PreserveLoggingContext():
+ _, key_id, verify_key = yield verify_request.deferred
+ except IOError as e:
+ logger.warn(
+ "Got IOError when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e),
+ )
+ raise SynapseError(
+ 502,
+ "Error downloading keys for %s" % (server_name,),
+ Codes.UNAUTHORIZED,
+ )
+ except Exception as e:
+ logger.exception(
+ "Got Exception when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e),
+ )
+ raise SynapseError(
+ 401,
+ "No key for %s with id %s" % (server_name, verify_request.key_ids),
+ Codes.UNAUTHORIZED,
+ )
+
+ json_object = verify_request.json_object
+
+ logger.debug("Got key %s %s:%s for server %s, verifying" % (
+ key_id, verify_key.alg, verify_key.version, server_name,
+ ))
+ try:
+ verify_signed_json(json_object, server_name, verify_key)
+ except Exception:
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s" % (
+ server_name, verify_key.alg, verify_key.version
+ ),
+ Codes.UNAUTHORIZED,
+ )
|