diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 1fea568eed..d08ee0aa91 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import (
+ preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
+ preserve_fn
+)
from twisted.internet import defer
@@ -142,40 +146,43 @@ class Keyring(object):
for server_name, _ in server_and_json
}
- # We want to wait for any previous lookups to complete before
- # proceeding.
- wait_on_deferred = self.wait_for_previous_lookups(
- [server_name for server_name, _ in server_and_json],
- server_to_deferred,
- )
+ with PreserveLoggingContext():
- # Actually start fetching keys.
- wait_on_deferred.addBoth(
- lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
- )
+ # We want to wait for any previous lookups to complete before
+ # proceeding.
+ wait_on_deferred = self.wait_for_previous_lookups(
+ [server_name for server_name, _ in server_and_json],
+ server_to_deferred,
+ )
- # When we've finished fetching all the keys for a given server_name,
- # resolve the deferred passed to `wait_for_previous_lookups` so that
- # any lookups waiting will proceed.
- server_to_gids = {}
+ # Actually start fetching keys.
+ wait_on_deferred.addBoth(
+ lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+ )
+
+ # When we've finished fetching all the keys for a given server_name,
+ # resolve the deferred passed to `wait_for_previous_lookups` so that
+ # any lookups waiting will proceed.
+ server_to_gids = {}
- def remove_deferreds(res, server_name, group_id):
- server_to_gids[server_name].discard(group_id)
- if not server_to_gids[server_name]:
- d = server_to_deferred.pop(server_name, None)
- if d:
- d.callback(None)
- return res
+ def remove_deferreds(res, server_name, group_id):
+ server_to_gids[server_name].discard(group_id)
+ if not server_to_gids[server_name]:
+ d = server_to_deferred.pop(server_name, None)
+ if d:
+ d.callback(None)
+ return res
- for g_id, deferred in deferreds.items():
- server_name = group_id_to_group[g_id].server_name
- server_to_gids.setdefault(server_name, set()).add(g_id)
- deferred.addBoth(remove_deferreds, server_name, g_id)
+ for g_id, deferred in deferreds.items():
+ server_name = group_id_to_group[g_id].server_name
+ server_to_gids.setdefault(server_name, set()).add(g_id)
+ deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
- handle_key_deferred(
+ preserve_context_over_fn(
+ handle_key_deferred,
group_id_to_group[g_id],
deferreds[g_id],
)
@@ -198,12 +205,13 @@ class Keyring(object):
if server_name in self.key_downloads
]
if wait_on:
- yield defer.DeferredList(wait_on)
+ with PreserveLoggingContext():
+ yield defer.DeferredList(wait_on)
else:
break
for server_name, deferred in server_to_deferred.items():
- d = ObservableDeferred(deferred)
+ d = ObservableDeferred(preserve_context_over_deferred(deferred))
self.key_downloads[server_name] = d
def rm(r, server_name):
@@ -244,12 +252,13 @@ class Keyring(object):
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
- group_id_to_deferred[group.group_id].callback((
- group.group_id,
- group.server_name,
- key_id,
- merged_results[group.server_name][key_id],
- ))
+ with PreserveLoggingContext():
+ group_id_to_deferred[group.group_id].callback((
+ group.group_id,
+ group.server_name,
+ key_id,
+ merged_results[group.server_name][key_id],
+ ))
break
else:
missing_groups.setdefault(
@@ -504,7 +513,7 @@ class Keyring(object):
yield defer.gatherResults(
[
- self.store_keys(
+ preserve_fn(self.store_keys)(
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
@@ -573,7 +582,7 @@ class Keyring(object):
yield defer.gatherResults(
[
- self.store.store_server_keys_json(
+ preserve_fn(self.store.store_server_keys_json)(
server_name=server_name,
key_id=key_id,
from_server=server_name,
@@ -675,7 +684,7 @@ class Keyring(object):
# TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults(
[
- self.store.store_server_verify_key(
+ preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
|