summary refs log tree commit diff
path: root/synapse/crypto
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto')
-rw-r--r--synapse/crypto/keyring.py92
1 files changed, 41 insertions, 51 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 341c863152..e8bb420ad1 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -238,27 +238,9 @@ class Keyring(object):
         """
 
         try:
-            # create a deferred for each server we're going to look up the keys
-            # for; we'll resolve them once we have completed our lookups.
-            # These will be passed into wait_for_previous_lookups to block
-            # any other lookups until we have finished.
-            # The deferreds are called with no logcontext.
-            server_to_deferred = {
-                rq.server_name: defer.Deferred() for rq in verify_requests
-            }
-
-            # We want to wait for any previous lookups to complete before
-            # proceeding.
-            yield self.wait_for_previous_lookups(server_to_deferred)
+            ctx = LoggingContext.current_context()
 
-            # Actually start fetching keys.
-            self._get_server_verify_keys(verify_requests)
-
-            # When we've finished fetching all the keys for a given server_name,
-            # resolve the deferred passed to `wait_for_previous_lookups` so that
-            # any lookups waiting will proceed.
-            #
-            # map from server name to a set of request ids
+            # map from server name to a set of outstanding request ids
             server_to_request_ids = {}
 
             for verify_request in verify_requests:
@@ -266,40 +248,61 @@ class Keyring(object):
                 request_id = id(verify_request)
                 server_to_request_ids.setdefault(server_name, set()).add(request_id)
 
-            def remove_deferreds(res, verify_request):
+            # Wait for any previous lookups to complete before proceeding.
+            yield self.wait_for_previous_lookups(server_to_request_ids.keys())
+
+            # take out a lock on each of the servers by sticking a Deferred in
+            # key_downloads
+            for server_name in server_to_request_ids.keys():
+                self.key_downloads[server_name] = defer.Deferred()
+                logger.debug("Got key lookup lock on %s", server_name)
+
+            # When we've finished fetching all the keys for a given server_name,
+            # drop the lock by resolving the deferred in key_downloads.
+            def drop_server_lock(server_name):
+                d = self.key_downloads.pop(server_name)
+                d.callback(None)
+
+            def lookup_done(res, verify_request):
                 server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids[server_name].discard(request_id)
-                if not server_to_request_ids[server_name]:
-                    d = server_to_deferred.pop(server_name, None)
-                    if d:
-                        d.callback(None)
+                server_requests = server_to_request_ids[server_name]
+                server_requests.remove(id(verify_request))
+
+                # if there are no more requests for this server, we can drop the lock.
+                if not server_requests:
+                    with PreserveLoggingContext(ctx):
+                        logger.debug("Releasing key lookup lock on %s", server_name)
+
+                    # ... but not immediately, as that can cause stack explosions if
+                    # we get a long queue of lookups.
+                    self.clock.call_later(0, drop_server_lock, server_name)
+
                 return res
 
             for verify_request in verify_requests:
-                verify_request.key_ready.addBoth(remove_deferreds, verify_request)
+                verify_request.key_ready.addBoth(lookup_done, verify_request)
+
+            # Actually start fetching keys.
+            self._get_server_verify_keys(verify_requests)
         except Exception:
             logger.exception("Error starting key lookups")
 
     @defer.inlineCallbacks
-    def wait_for_previous_lookups(self, server_to_deferred):
+    def wait_for_previous_lookups(self, server_names):
         """Waits for any previous key lookups for the given servers to finish.
 
         Args:
-            server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
-                resolved once we've finished looking up keys for that server.
-                The Deferreds should be regular twisted ones which call their
-                callbacks with no logcontext.
-
-        Returns: a Deferred which resolves once all key lookups for the given
-            servers have completed. Follows the synapse rules of logcontext
-            preservation.
+            server_names (Iterable[str]): list of servers which we want to look up
+
+        Returns:
+            Deferred[None]: resolves once all key lookups for the given servers have
+                completed. Follows the synapse rules of logcontext preservation.
         """
         loop_count = 1
         while True:
             wait_on = [
                 (server_name, self.key_downloads[server_name])
-                for server_name in server_to_deferred.keys()
+                for server_name in server_names
                 if server_name in self.key_downloads
             ]
             if not wait_on:
@@ -314,19 +317,6 @@ class Keyring(object):
 
             loop_count += 1
 
-        ctx = LoggingContext.current_context()
-
-        def rm(r, server_name_):
-            with PreserveLoggingContext(ctx):
-                logger.debug("Releasing key lookup lock on %s", server_name_)
-                self.key_downloads.pop(server_name_, None)
-            return r
-
-        for server_name, deferred in server_to_deferred.items():
-            logger.debug("Got key lookup lock on %s", server_name)
-            self.key_downloads[server_name] = deferred
-            deferred.addBoth(rm, server_name)
-
     def _get_server_verify_keys(self, verify_requests):
         """Tries to find at least one key for each verify request