summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py106
1 files changed, 48 insertions, 58 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 341c863152..6c3e885e72 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -238,27 +238,9 @@ class Keyring(object):
         """
 
         try:
-            # create a deferred for each server we're going to look up the keys
-            # for; we'll resolve them once we have completed our lookups.
-            # These will be passed into wait_for_previous_lookups to block
-            # any other lookups until we have finished.
-            # The deferreds are called with no logcontext.
-            server_to_deferred = {
-                rq.server_name: defer.Deferred() for rq in verify_requests
-            }
-
-            # We want to wait for any previous lookups to complete before
-            # proceeding.
-            yield self.wait_for_previous_lookups(server_to_deferred)
+            ctx = LoggingContext.current_context()
 
-            # Actually start fetching keys.
-            self._get_server_verify_keys(verify_requests)
-
-            # When we've finished fetching all the keys for a given server_name,
-            # resolve the deferred passed to `wait_for_previous_lookups` so that
-            # any lookups waiting will proceed.
-            #
-            # map from server name to a set of request ids
+            # map from server name to a set of outstanding request ids
             server_to_request_ids = {}
 
             for verify_request in verify_requests:
@@ -266,40 +248,61 @@ class Keyring(object):
                 request_id = id(verify_request)
                 server_to_request_ids.setdefault(server_name, set()).add(request_id)
 
-            def remove_deferreds(res, verify_request):
+            # Wait for any previous lookups to complete before proceeding.
+            yield self.wait_for_previous_lookups(server_to_request_ids.keys())
+
+            # take out a lock on each of the servers by sticking a Deferred in
+            # key_downloads
+            for server_name in server_to_request_ids.keys():
+                self.key_downloads[server_name] = defer.Deferred()
+                logger.debug("Got key lookup lock on %s", server_name)
+
+            # When we've finished fetching all the keys for a given server_name,
+            # drop the lock by resolving the deferred in key_downloads.
+            def drop_server_lock(server_name):
+                d = self.key_downloads.pop(server_name)
+                d.callback(None)
+
+            def lookup_done(res, verify_request):
                 server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids[server_name].discard(request_id)
-                if not server_to_request_ids[server_name]:
-                    d = server_to_deferred.pop(server_name, None)
-                    if d:
-                        d.callback(None)
+                server_requests = server_to_request_ids[server_name]
+                server_requests.remove(id(verify_request))
+
+                # if there are no more requests for this server, we can drop the lock.
+                if not server_requests:
+                    with PreserveLoggingContext(ctx):
+                        logger.debug("Releasing key lookup lock on %s", server_name)
+
+                    # ... but not immediately, as that can cause stack explosions if
+                    # we get a long queue of lookups.
+                    self.clock.call_later(0, drop_server_lock, server_name)
+
                 return res
 
             for verify_request in verify_requests:
-                verify_request.key_ready.addBoth(remove_deferreds, verify_request)
+                verify_request.key_ready.addBoth(lookup_done, verify_request)
+
+            # Actually start fetching keys.
+            self._get_server_verify_keys(verify_requests)
         except Exception:
             logger.exception("Error starting key lookups")
 
     @defer.inlineCallbacks
-    def wait_for_previous_lookups(self, server_to_deferred):
+    def wait_for_previous_lookups(self, server_names):
         """Waits for any previous key lookups for the given servers to finish.
 
         Args:
-            server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
-                resolved once we've finished looking up keys for that server.
-                The Deferreds should be regular twisted ones which call their
-                callbacks with no logcontext.
-
-        Returns: a Deferred which resolves once all key lookups for the given
-            servers have completed. Follows the synapse rules of logcontext
-            preservation.
+            server_names (Iterable[str]): list of servers which we want to look up
+
+        Returns:
+            Deferred[None]: resolves once all key lookups for the given servers have
+                completed. Follows the synapse rules of logcontext preservation.
         """
         loop_count = 1
         while True:
             wait_on = [
                 (server_name, self.key_downloads[server_name])
-                for server_name in server_to_deferred.keys()
+                for server_name in server_names
                 if server_name in self.key_downloads
             ]
             if not wait_on:
@@ -314,19 +317,6 @@ class Keyring(object):
 
             loop_count += 1
 
-        ctx = LoggingContext.current_context()
-
-        def rm(r, server_name_):
-            with PreserveLoggingContext(ctx):
-                logger.debug("Releasing key lookup lock on %s", server_name_)
-                self.key_downloads.pop(server_name_, None)
-            return r
-
-        for server_name, deferred in server_to_deferred.items():
-            logger.debug("Got key lookup lock on %s", server_name)
-            self.key_downloads[server_name] = deferred
-            deferred.addBoth(rm, server_name)
-
     def _get_server_verify_keys(self, verify_requests):
         """Tries to find at least one key for each verify request
 
@@ -472,7 +462,7 @@ class StoreKeyFetcher(KeyFetcher):
         keys = {}
         for (server_name, key_id), key in res.items():
             keys.setdefault(server_name, {})[key_id] = key
-        defer.returnValue(keys)
+        return keys
 
 
 class BaseV2KeyFetcher(object):
@@ -576,7 +566,7 @@ class BaseV2KeyFetcher(object):
             ).addErrback(unwrapFirstError)
         )
 
-        defer.returnValue(verify_keys)
+        return verify_keys
 
 
 class PerspectivesKeyFetcher(BaseV2KeyFetcher):
@@ -598,7 +588,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 result = yield self.get_server_verify_key_v2_indirect(
                     keys_to_fetch, key_server
                 )
-                defer.returnValue(result)
+                return result
             except KeyLookupError as e:
                 logger.warning(
                     "Key lookup failed from %r: %s", key_server.server_name, e
@@ -611,7 +601,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                     str(e),
                 )
 
-            defer.returnValue({})
+            return {}
 
         results = yield make_deferred_yieldable(
             defer.gatherResults(
@@ -625,7 +615,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             for server_name, keys in result.items():
                 union_of_keys.setdefault(server_name, {}).update(keys)
 
-        defer.returnValue(union_of_keys)
+        return union_of_keys
 
     @defer.inlineCallbacks
     def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
@@ -711,7 +701,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             perspective_name, time_now_ms, added_keys
         )
 
-        defer.returnValue(keys)
+        return keys
 
     def _validate_perspectives_response(self, key_server, response):
         """Optionally check the signature on the result of a /key/query request
@@ -853,7 +843,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             )
             keys.update(response_keys)
 
-        defer.returnValue(keys)
+        return keys
 
 
 @defer.inlineCallbacks