diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 4911f0896b..24f15f3154 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -18,7 +18,9 @@ from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import (
+ preserve_context_over_fn, preserve_context_over_deferred
+)
import simplejson as json
import logging
@@ -40,11 +42,14 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
for i in range(5):
try:
- with PreserveLoggingContext():
- protocol = yield endpoint.connect(factory)
- server_response, server_certificate = yield protocol.remote_key
- defer.returnValue((server_response, server_certificate))
- return
+ protocol = yield preserve_context_over_fn(
+ endpoint.connect, factory
+ )
+ server_response, server_certificate = yield preserve_context_over_deferred(
+ protocol.remote_key
+ )
+ defer.returnValue((server_response, server_certificate))
+ return
except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 8709394b97..aff69c5f83 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
-from synapse.util.async import create_observer
+from synapse.util.async import ObservableDeferred
from OpenSSL import crypto
@@ -111,6 +111,10 @@ class Keyring(object):
if download is None:
download = self._get_server_verify_key_impl(server_name, key_ids)
+ download = ObservableDeferred(
+ download,
+ consumeErrors=True
+ )
self.key_downloads[server_name] = download
@download.addBoth
@@ -118,30 +122,31 @@ class Keyring(object):
del self.key_downloads[server_name]
return ret
- r = yield create_observer(download)
+ r = yield download.observe()
defer.returnValue(r)
@defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids):
keys = None
- perspective_results = []
- for perspective_name, perspective_keys in self.perspective_servers.items():
- @defer.inlineCallbacks
- def get_key():
- try:
- result = yield self.get_server_verify_key_v2_indirect(
- server_name, key_ids, perspective_name, perspective_keys
- )
- defer.returnValue(result)
- except:
- logging.info(
- "Unable to getting key %r for %r from %r",
- key_ids, server_name, perspective_name,
- )
- perspective_results.append(get_key())
+ @defer.inlineCallbacks
+ def get_key(perspective_name, perspective_keys):
+ try:
+ result = yield self.get_server_verify_key_v2_indirect(
+ server_name, key_ids, perspective_name, perspective_keys
+ )
+ defer.returnValue(result)
+ except Exception as e:
+ logging.info(
+ "Unable to getting key %r for %r from %r: %s %s",
+ key_ids, server_name, perspective_name,
+ type(e).__name__, str(e.message),
+ )
- perspective_results = yield defer.gatherResults(perspective_results)
+ perspective_results = yield defer.gatherResults([
+ get_key(p_name, p_keys)
+ for p_name, p_keys in self.perspective_servers.items()
+ ])
for results in perspective_results:
if results is not None:
@@ -154,17 +159,22 @@ class Keyring(object):
)
with limiter:
- if keys is None:
+ if not keys:
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
- except:
- pass
+ except Exception as e:
+ logging.info(
+ "Unable to getting key %r for %r directly: %s %s",
+ key_ids, server_name,
+ type(e).__name__, str(e.message),
+ )
- keys = yield self.get_server_verify_key_v1_direct(
- server_name, key_ids
- )
+ if not keys:
+ keys = yield self.get_server_verify_key_v1_direct(
+ server_name, key_ids
+ )
for key_id in key_ids:
if key_id in keys:
@@ -184,7 +194,7 @@ class Keyring(object):
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
- responses = yield self.client.post_json(
+ query_response = yield self.client.post_json(
destination=perspective_name,
path=b"/_matrix/key/v2/query",
data={
@@ -200,6 +210,8 @@ class Keyring(object):
keys = {}
+ responses = query_response["server_keys"]
+
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
@@ -323,7 +335,7 @@ class Keyring(object):
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
- for key_id in response_json["signatures"][server_name]:
+ for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise ValueError(
"Key response must include verification keys for all"
|