diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index ebf4e2e7a6..2a1d383078 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -18,7 +18,7 @@ from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes
from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import (
- preserve_context_over_fn, PreserveLoggingContext,
+ PreserveLoggingContext,
preserve_fn
)
from synapse.util.metrics import Measure
@@ -57,7 +57,8 @@ Attributes:
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when
- a verify key has been fetched
+ a verify key has been fetched. The deferreds' callbacks are run with no
+ logcontext.
"""
@@ -82,9 +83,11 @@ class Keyring(object):
self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object):
- return self.verify_json_objects_for_server(
- [(server_name, json_object)]
- )[0]
+ return logcontext.make_deferred_yieldable(
+ self.verify_json_objects_for_server(
+ [(server_name, json_object)]
+ )[0]
+ )
def verify_json_objects_for_server(self, server_and_json):
"""Bulk verifies signatures of json objects, bulk fetching keys as
@@ -94,8 +97,10 @@ class Keyring(object):
server_and_json (list): List of pairs of (server_name, json_object)
Returns:
- list of deferreds indicating success or failure to verify each
- json object's signature for the given server_name.
+ List<Deferred>: for each input pair, a deferred indicating success
+ or failure to verify each json object's signature for the given
+ server_name. The deferreds run their callbacks in the sentinel
+ logcontext.
"""
verify_requests = []
@@ -122,95 +127,71 @@ class Keyring(object):
verify_requests.append(verify_request)
- @defer.inlineCallbacks
- def handle_key_deferred(verify_request):
- server_name = verify_request.server_name
- try:
- _, key_id, verify_key = yield verify_request.deferred
- except IOError as e:
- logger.warn(
- "Got IOError when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
- )
- except Exception as e:
- logger.exception(
- "Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 401,
- "No key for %s with id %s" % (server_name, key_ids),
- Codes.UNAUTHORIZED,
- )
+ preserve_fn(self._start_key_lookups)(verify_requests)
- json_object = verify_request.json_object
+ # Pass those keys to handle_key_deferred so that the json object
+ # signatures can be verified
+ handle = preserve_fn(_handle_key_deferred)
+ return [
+ handle(rq) for rq in verify_requests
+ ]
- logger.debug("Got key %s %s:%s for server %s, verifying" % (
- key_id, verify_key.alg, verify_key.version, server_name,
- ))
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except:
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s" % (
- server_name, verify_key.alg, verify_key.version
- ),
- Codes.UNAUTHORIZED,
- )
+ @defer.inlineCallbacks
+ def _start_key_lookups(self, verify_requests):
+ """Sets off the key fetches for each verify request
+
+ Once each fetch completes, verify_request.deferred will be resolved.
+
+ Args:
+ verify_requests (List[VerifyKeyRequest]):
+ """
+ # create a deferred for each server we're going to look up the keys
+ # for; we'll resolve them once we have completed our lookups.
+ # These will be passed into wait_for_previous_lookups to block
+ # any other lookups until we have finished.
+ # The deferreds are called with no logcontext.
server_to_deferred = {
- server_name: defer.Deferred()
- for server_name, _ in server_and_json
+ rq.server_name: defer.Deferred()
+ for rq in verify_requests
}
- with PreserveLoggingContext():
+ # We want to wait for any previous lookups to complete before
+ # proceeding.
+ yield self.wait_for_previous_lookups(
+ [rq.server_name for rq in verify_requests],
+ server_to_deferred,
+ )
- # We want to wait for any previous lookups to complete before
- # proceeding.
- wait_on_deferred = self.wait_for_previous_lookups(
- [server_name for server_name, _ in server_and_json],
- server_to_deferred,
- )
+ # Actually start fetching keys.
+ self._get_server_verify_keys(verify_requests)
- # Actually start fetching keys.
- wait_on_deferred.addBoth(
- lambda _: self.get_server_verify_keys(verify_requests)
- )
+ # When we've finished fetching all the keys for a given server_name,
+ # resolve the deferred passed to `wait_for_previous_lookups` so that
+ # any lookups waiting will proceed.
+ #
+ # map from server name to a set of request ids
+ server_to_request_ids = {}
- # When we've finished fetching all the keys for a given server_name,
- # resolve the deferred passed to `wait_for_previous_lookups` so that
- # any lookups waiting will proceed.
- server_to_request_ids = {}
-
- def remove_deferreds(res, server_name, verify_request):
- request_id = id(verify_request)
- server_to_request_ids[server_name].discard(request_id)
- if not server_to_request_ids[server_name]:
- d = server_to_deferred.pop(server_name, None)
- if d:
- d.callback(None)
- return res
-
- for verify_request in verify_requests:
- server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids.setdefault(server_name, set()).add(request_id)
- verify_request.deferred.addBoth(
- remove_deferreds, server_name, verify_request,
- )
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids.setdefault(server_name, set()).add(request_id)
- # Pass those keys to handle_key_deferred so that the json object
- # signatures can be verified
- return [
- preserve_context_over_fn(handle_key_deferred, verify_request)
- for verify_request in verify_requests
- ]
+ def remove_deferreds(res, verify_request):
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids[server_name].discard(request_id)
+ if not server_to_request_ids[server_name]:
+ d = server_to_deferred.pop(server_name, None)
+ if d:
+ d.callback(None)
+ return res
+
+ for verify_request in verify_requests:
+ verify_request.deferred.addBoth(
+ remove_deferreds, verify_request,
+ )
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred):
@@ -247,7 +228,7 @@ class Keyring(object):
self.key_downloads[server_name] = deferred
deferred.addBoth(rm, server_name)
- def get_server_verify_keys(self, verify_requests):
+ def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request
For each verify_request, verify_request.deferred is called back with
@@ -316,21 +297,23 @@ class Keyring(object):
if not missing_keys:
break
- for verify_request in requests_missing_keys.values():
- verify_request.deferred.errback(SynapseError(
- 401,
- "No key for %s with id %s" % (
- verify_request.server_name, verify_request.key_ids,
- ),
- Codes.UNAUTHORIZED,
- ))
+ with PreserveLoggingContext():
+ for verify_request in requests_missing_keys.values():
+ verify_request.deferred.errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ verify_request.server_name, verify_request.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
def on_err(err):
- for verify_request in verify_requests:
- if not verify_request.deferred.called:
- verify_request.deferred.errback(err)
+ with PreserveLoggingContext():
+ for verify_request in verify_requests:
+ if not verify_request.deferred.called:
+ verify_request.deferred.errback(err)
- do_iterations().addErrback(on_err)
+ preserve_fn(do_iterations)().addErrback(on_err)
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
@@ -740,3 +723,47 @@ class Keyring(object):
],
consumeErrors=True,
).addErrback(unwrapFirstError))
+
+
+@defer.inlineCallbacks
+def _handle_key_deferred(verify_request):
+ server_name = verify_request.server_name
+ try:
+ with PreserveLoggingContext():
+ _, key_id, verify_key = yield verify_request.deferred
+ except IOError as e:
+ logger.warn(
+ "Got IOError when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 502,
+ "Error downloading keys for %s" % (server_name,),
+ Codes.UNAUTHORIZED,
+ )
+ except Exception as e:
+ logger.exception(
+ "Got Exception when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 401,
+ "No key for %s with id %s" % (server_name, verify_request.key_ids),
+ Codes.UNAUTHORIZED,
+ )
+
+ json_object = verify_request.json_object
+
+ logger.debug("Got key %s %s:%s for server %s, verifying" % (
+ key_id, verify_key.alg, verify_key.version, server_name,
+ ))
+ try:
+ verify_signed_json(json_object, server_name, verify_key)
+ except:
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s" % (
+ server_name, verify_key.alg, verify_key.version
+ ),
+ Codes.UNAUTHORIZED,
+ )
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 28eaab2cef..babd9ea078 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -18,8 +18,7 @@ from synapse.api.errors import SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import spamcheck
from synapse.events.utils import prune_event
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_context_over_deferred, preserve_fn
+from synapse.util import unwrapFirstError, logcontext
from twisted.internet import defer
logger = logging.getLogger(__name__)
@@ -51,56 +50,52 @@ class FederationBase(object):
"""
deferreds = self._check_sigs_and_hashes(pdus)
- def callback(pdu):
- return pdu
+ @defer.inlineCallbacks
+ def handle_check_result(pdu, deferred):
+ try:
+ res = yield logcontext.make_deferred_yieldable(deferred)
+ except SynapseError:
+ res = None
- def errback(failure, pdu):
- failure.trap(SynapseError)
- return None
-
- def try_local_db(res, pdu):
if not res:
# Check local db.
- return self.store.get_event(
+ res = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
- return res
- def try_remote(res, pdu):
if not res and pdu.origin != origin:
- return self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- outlier=outlier,
- timeout=10000,
- ).addErrback(lambda e: None)
- return res
-
- def warn(res, pdu):
+ try:
+ res = yield self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ outlier=outlier,
+ timeout=10000,
+ )
+ except SynapseError:
+ pass
+
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
- return res
- for pdu, deferred in zip(pdus, deferreds):
- deferred.addCallbacks(
- callback, errback, errbackArgs=[pdu]
- ).addCallback(
- try_local_db, pdu
- ).addCallback(
- try_remote, pdu
- ).addCallback(
- warn, pdu
- )
+ defer.returnValue(res)
- valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
- deferreds,
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ handle = logcontext.preserve_fn(handle_check_result)
+ deferreds2 = [
+ handle(pdu, deferred)
+ for pdu, deferred in zip(pdus, deferreds)
+ ]
+
+ valid_pdus = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ deferreds2,
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
if include_none:
defer.returnValue(valid_pdus)
@@ -108,7 +103,9 @@ class FederationBase(object):
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu):
- return self._check_sigs_and_hashes([pdu])[0]
+ return logcontext.make_deferred_yieldable(
+ self._check_sigs_and_hashes([pdu])[0],
+ )
def _check_sigs_and_hashes(self, pdus):
"""Checks that each of the received events is correctly signed by the
@@ -123,6 +120,7 @@ class FederationBase(object):
* returns a redacted version of the event (if the signature
matched but the hash did not)
* throws a SynapseError if the signature check failed.
+ The deferreds run their callbacks in the sentinel logcontext.
"""
redacted_pdus = [
@@ -130,34 +128,38 @@ class FederationBase(object):
for pdu in pdus
]
- deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+ deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
- def callback(_, pdu, redacted):
- if not check_event_content_hash(pdu):
- logger.warn(
- "Event content has been tampered, redacting %s: %s",
- pdu.event_id, pdu.get_pdu_json()
- )
- return redacted
-
- if spamcheck.check_event_for_spam(pdu):
- logger.warn(
- "Event contains spam, redacting %s: %s",
- pdu.event_id, pdu.get_pdu_json()
- )
- return redacted
+ ctx = logcontext.LoggingContext.current_context()
- return pdu
+ def callback(_, pdu, redacted):
+ with logcontext.PreserveLoggingContext(ctx):
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ if spamcheck.check_event_for_spam(pdu):
+ logger.warn(
+ "Event contains spam, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
- logger.warn(
- "Signature check failed for %s",
- pdu.event_id,
- )
+ with logcontext.PreserveLoggingContext(ctx):
+ logger.warn(
+ "Signature check failed for %s",
+ pdu.event_id,
+ )
return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 861441708b..7c5e5d957f 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -22,7 +22,7 @@ from synapse.api.constants import Membership
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@@ -189,10 +189,10 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+ pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(pdus)
@@ -252,7 +252,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+ signed_pdu = yield self._check_sigs_and_hash(pdu)
break
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 3b5e0a4fb9..87aeaf71d6 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -113,30 +113,37 @@ class KeyStore(SQLBaseStore):
keys[key_id] = key
defer.returnValue(keys)
- @defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server.
Args:
server_name (str): The name of the server.
- key_id (str): The version of the key for the server.
from_server (str): Where the verification key was looked up
- ts_now_ms (int): The time now in milliseconds
- verification_key (VerifyKey): The NACL verify key.
+ time_now_ms (int): The time now in milliseconds
+ verify_key (nacl.signing.VerifyKey): The NACL verify key.
"""
- yield self._simple_upsert(
- table="server_signature_keys",
- keyvalues={
- "server_name": server_name,
- "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
- },
- values={
- "from_server": from_server,
- "ts_added_ms": time_now_ms,
- "verify_key": buffer(verify_key.encode()),
- },
- desc="store_server_verify_key",
- )
+ key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+ def _txn(txn):
+ self._simple_upsert_txn(
+ txn,
+ table="server_signature_keys",
+ keyvalues={
+ "server_name": server_name,
+ "key_id": key_id,
+ },
+ values={
+ "from_server": from_server,
+ "ts_added_ms": time_now_ms,
+ "verify_key": buffer(verify_key.encode()),
+ },
+ )
+ txn.call_after(
+ self._get_server_verify_key.invalidate,
+ (server_name, key_id)
+ )
+
+ return self.runInteraction("store_server_verify_key", _txn)
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
|