summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2017-09-20 01:32:42 +0100
committerRichard van der Hoff <richard@matrix.org>2017-09-20 01:32:42 +0100
commitc5b0e9f48542516a4fa82247c81e499894340cf5 (patch)
treeedbecbb4157ac30165a2d6bef434091fcaaed45d /synapse/crypto/keyring.py
parentFix potential race in _start_key_lookups (diff)
downloadsynapse-c5b0e9f48542516a4fa82247c81e499894340cf5.tar.xz
Turn _start_key_lookups into an inlineCallbacks function
... which means that logcontexts can be correctly preserved for the stuff it
does.

get_server_verify_keys is now called with the logcontext, so needs to
preserve_fn when it fires off its nested inlineCallbacks function.

Also renames get_server_verify_keys to reflect the fact it's meant to be
private.
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py77
1 files changed, 37 insertions, 40 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 0e381c4710..7e4cef13c1 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -123,7 +123,7 @@ class Keyring(object):
 
             verify_requests.append(verify_request)
 
-        self._start_key_lookups(verify_requests)
+        preserve_fn(self._start_key_lookups)(verify_requests)
 
         # Pass those keys to handle_key_deferred so that the json object
         # signatures can be verified
@@ -132,6 +132,7 @@ class Keyring(object):
             for rq in verify_requests
         ]
 
+    @defer.inlineCallbacks
     def _start_key_lookups(self, verify_requests):
         """Sets off the key fetches for each verify request
 
@@ -151,47 +152,43 @@ class Keyring(object):
             for rq in verify_requests
         }
 
-        with PreserveLoggingContext():
+        # We want to wait for any previous lookups to complete before
+        # proceeding.
+        yield self.wait_for_previous_lookups(
+            [rq.server_name for rq in verify_requests],
+            server_to_deferred,
+        )
 
-            # We want to wait for any previous lookups to complete before
-            # proceeding.
-            wait_on_deferred = self.wait_for_previous_lookups(
-                [rq.server_name for rq in verify_requests],
-                server_to_deferred,
-            )
+        # Actually start fetching keys.
+        self._get_server_verify_keys(verify_requests)
 
-            # Actually start fetching keys.
-            wait_on_deferred.addBoth(
-                lambda _: self.get_server_verify_keys(verify_requests)
+        # When we've finished fetching all the keys for a given server_name,
+        # resolve the deferred passed to `wait_for_previous_lookups` so that
+        # any lookups waiting will proceed.
+        #
+        # map from server name to a set of request ids
+        server_to_request_ids = {}
+
+        for verify_request in verify_requests:
+            server_name = verify_request.server_name
+            request_id = id(verify_request)
+            server_to_request_ids.setdefault(server_name, set()).add(request_id)
+
+        def remove_deferreds(res, verify_request):
+            server_name = verify_request.server_name
+            request_id = id(verify_request)
+            server_to_request_ids[server_name].discard(request_id)
+            if not server_to_request_ids[server_name]:
+                d = server_to_deferred.pop(server_name, None)
+                if d:
+                    d.callback(None)
+            return res
+
+        for verify_request in verify_requests:
+            verify_request.deferred.addBoth(
+                remove_deferreds, verify_request,
             )
 
-            # When we've finished fetching all the keys for a given server_name,
-            # resolve the deferred passed to `wait_for_previous_lookups` so that
-            # any lookups waiting will proceed.
-            #
-            # map from server name to a set of request ids
-            server_to_request_ids = {}
-
-            for verify_request in verify_requests:
-                server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids.setdefault(server_name, set()).add(request_id)
-
-            def remove_deferreds(res, verify_request):
-                server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids[server_name].discard(request_id)
-                if not server_to_request_ids[server_name]:
-                    d = server_to_deferred.pop(server_name, None)
-                    if d:
-                        d.callback(None)
-                return res
-
-            for verify_request in verify_requests:
-                verify_request.deferred.addBoth(
-                    remove_deferreds, verify_request,
-                )
-
     @defer.inlineCallbacks
     def wait_for_previous_lookups(self, server_names, server_to_deferred):
         """Waits for any previous key lookups for the given servers to finish.
@@ -227,7 +224,7 @@ class Keyring(object):
             self.key_downloads[server_name] = deferred
             deferred.addBoth(rm, server_name)
 
-    def get_server_verify_keys(self, verify_requests):
+    def _get_server_verify_keys(self, verify_requests):
         """Tries to find at least one key for each verify request
 
         For each verify_request, verify_request.deferred is called back with
@@ -312,7 +309,7 @@ class Keyring(object):
                     if not verify_request.deferred.called:
                         verify_request.deferred.errback(err)
 
-        do_iterations().addErrback(on_err)
+        preserve_fn(do_iterations)().addErrback(on_err)
 
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):