summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py64
1 files changed, 36 insertions, 28 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 1bb27edc0f..51851d04e5 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,10 +16,9 @@
 
 from synapse.crypto.keyclient import fetch_server_key
 from synapse.api.errors import SynapseError, Codes
-from synapse.util import unwrapFirstError
-from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.logcontext import (
-    preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
+    preserve_context_over_fn, PreserveLoggingContext,
     preserve_fn
 )
 from synapse.util.metrics import Measure
@@ -74,6 +74,11 @@ class Keyring(object):
         self.perspective_servers = self.config.perspectives
         self.hs = hs
 
+        # map from server name to Deferred. Has an entry for each server with
+        # an ongoing key download; the Deferred completes once the download
+        # completes.
+        #
+        # These are regular, logcontext-agnostic Deferreds.
         self.key_downloads = {}
 
     def verify_json_for_server(self, server_name, json_object):
@@ -82,7 +87,7 @@ class Keyring(object):
         )[0]
 
     def verify_json_objects_for_server(self, server_and_json):
-        """Bulk verfies signatures of json objects, bulk fetching keys as
+        """Bulk verifies signatures of json objects, bulk fetching keys as
         necessary.
 
         Args:
@@ -212,7 +217,13 @@ class Keyring(object):
         Args:
             server_names (list): list of server_names we want to lookup
             server_to_deferred (dict): server_name to deferred which gets
-                resolved once we've finished looking up keys for that server
+                resolved once we've finished looking up keys for that server.
+                The Deferreds should be regular twisted ones which call their
+                callbacks with no logcontext.
+
+        Returns: a Deferred which resolves once all key lookups for the given
+            servers have completed. Follows the synapse rules of logcontext
+            preservation.
         """
         while True:
             wait_on = [
@@ -226,15 +237,13 @@ class Keyring(object):
             else:
                 break
 
-        for server_name, deferred in server_to_deferred.items():
-            d = ObservableDeferred(preserve_context_over_deferred(deferred))
-            self.key_downloads[server_name] = d
-
-            def rm(r, server_name):
-                self.key_downloads.pop(server_name, None)
-                return r
+        def rm(r, server_name_):
+            self.key_downloads.pop(server_name_, None)
+            return r
 
-            d.addBoth(rm, server_name)
+        for server_name, deferred in server_to_deferred.items():
+            self.key_downloads[server_name] = deferred
+            deferred.addBoth(rm, server_name)
 
     def get_server_verify_keys(self, verify_requests):
         """Tries to find at least one key for each verify request
@@ -333,7 +342,7 @@ class Keyring(object):
             Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
                 server_name -> key_id -> VerifyKey
         """
-        res = yield preserve_context_over_deferred(defer.gatherResults(
+        res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
                 preserve_fn(self.store.get_server_verify_keys)(
                     server_name, key_ids
@@ -341,7 +350,7 @@ class Keyring(object):
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(dict(res))
 
@@ -362,13 +371,13 @@ class Keyring(object):
                 )
                 defer.returnValue({})
 
-        results = yield preserve_context_over_deferred(defer.gatherResults(
+        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
                 preserve_fn(get_key)(p_name, p_keys)
                 for p_name, p_keys in self.perspective_servers.items()
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         union_of_keys = {}
         for result in results:
@@ -402,13 +411,13 @@ class Keyring(object):
 
             defer.returnValue(keys)
 
-        results = yield preserve_context_over_deferred(defer.gatherResults(
+        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
                 preserve_fn(get_key)(server_name, key_ids)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         merged = {}
         for result in results:
@@ -485,7 +494,7 @@ class Keyring(object):
             for server_name, response_keys in processed_response.items():
                 keys.setdefault(server_name, {}).update(response_keys)
 
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
                 preserve_fn(self.store_keys)(
                     server_name=server_name,
@@ -495,7 +504,7 @@ class Keyring(object):
                 for server_name, response_keys in keys.items()
             ],
             consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(keys)
 
@@ -543,7 +552,7 @@ class Keyring(object):
 
             keys.update(response_keys)
 
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
                 preserve_fn(self.store_keys)(
                     server_name=key_server_name,
@@ -553,7 +562,7 @@ class Keyring(object):
                 for key_server_name, verify_keys in keys.items()
             ],
             consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(keys)
 
@@ -619,7 +628,7 @@ class Keyring(object):
         response_keys.update(verify_keys)
         response_keys.update(old_verify_keys)
 
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
                 preserve_fn(self.store.store_server_keys_json)(
                     server_name=server_name,
@@ -632,7 +641,7 @@ class Keyring(object):
                 for key_id in updated_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         results[server_name] = response_keys
 
@@ -710,7 +719,6 @@ class Keyring(object):
 
         defer.returnValue(verify_keys)
 
-    @defer.inlineCallbacks
     def store_keys(self, server_name, from_server, verify_keys):
         """Store a collection of verify keys for a given server
         Args:
@@ -721,7 +729,7 @@ class Keyring(object):
             A deferred that completes when the keys are stored.
         """
         # TODO(markjh): Store whether the keys have expired.
-        yield preserve_context_over_deferred(defer.gatherResults(
+        return logcontext.make_deferred_yieldable(defer.gatherResults(
             [
                 preserve_fn(self.store.store_server_verify_key)(
                     server_name, server_name, key.time_added, key
@@ -729,4 +737,4 @@ class Keyring(object):
                 for key_id, key in verify_keys.items()
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))