summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/5724.bugfix1
-rw-r--r--synapse/crypto/keyring.py92
-rw-r--r--tests/crypto/test_keyring.py29
3 files changed, 42 insertions, 80 deletions
diff --git a/changelog.d/5724.bugfix b/changelog.d/5724.bugfix
new file mode 100644
index 0000000000..1b3683daf6
--- /dev/null
+++ b/changelog.d/5724.bugfix
@@ -0,0 +1 @@
+Fix stack overflow in server key lookup code.
\ No newline at end of file
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 341c863152..e8bb420ad1 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -238,27 +238,9 @@ class Keyring(object):
         """
 
         try:
-            # create a deferred for each server we're going to look up the keys
-            # for; we'll resolve them once we have completed our lookups.
-            # These will be passed into wait_for_previous_lookups to block
-            # any other lookups until we have finished.
-            # The deferreds are called with no logcontext.
-            server_to_deferred = {
-                rq.server_name: defer.Deferred() for rq in verify_requests
-            }
-
-            # We want to wait for any previous lookups to complete before
-            # proceeding.
-            yield self.wait_for_previous_lookups(server_to_deferred)
+            ctx = LoggingContext.current_context()
 
-            # Actually start fetching keys.
-            self._get_server_verify_keys(verify_requests)
-
-            # When we've finished fetching all the keys for a given server_name,
-            # resolve the deferred passed to `wait_for_previous_lookups` so that
-            # any lookups waiting will proceed.
-            #
-            # map from server name to a set of request ids
+            # map from server name to a set of outstanding request ids
             server_to_request_ids = {}
 
             for verify_request in verify_requests:
@@ -266,40 +248,61 @@ class Keyring(object):
                 request_id = id(verify_request)
                 server_to_request_ids.setdefault(server_name, set()).add(request_id)
 
-            def remove_deferreds(res, verify_request):
+            # Wait for any previous lookups to complete before proceeding.
+            yield self.wait_for_previous_lookups(server_to_request_ids.keys())
+
+            # take out a lock on each of the servers by sticking a Deferred in
+            # key_downloads
+            for server_name in server_to_request_ids.keys():
+                self.key_downloads[server_name] = defer.Deferred()
+                logger.debug("Got key lookup lock on %s", server_name)
+
+            # When we've finished fetching all the keys for a given server_name,
+            # drop the lock by resolving the deferred in key_downloads.
+            def drop_server_lock(server_name):
+                d = self.key_downloads.pop(server_name)
+                d.callback(None)
+
+            def lookup_done(res, verify_request):
                 server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids[server_name].discard(request_id)
-                if not server_to_request_ids[server_name]:
-                    d = server_to_deferred.pop(server_name, None)
-                    if d:
-                        d.callback(None)
+                server_requests = server_to_request_ids[server_name]
+                server_requests.remove(id(verify_request))
+
+                # if there are no more requests for this server, we can drop the lock.
+                if not server_requests:
+                    with PreserveLoggingContext(ctx):
+                        logger.debug("Releasing key lookup lock on %s", server_name)
+
+                    # ... but not immediately, as that can cause stack explosions if
+                    # we get a long queue of lookups.
+                    self.clock.call_later(0, drop_server_lock, server_name)
+
                 return res
 
             for verify_request in verify_requests:
-                verify_request.key_ready.addBoth(remove_deferreds, verify_request)
+                verify_request.key_ready.addBoth(lookup_done, verify_request)
+
+            # Actually start fetching keys.
+            self._get_server_verify_keys(verify_requests)
         except Exception:
             logger.exception("Error starting key lookups")
 
     @defer.inlineCallbacks
-    def wait_for_previous_lookups(self, server_to_deferred):
+    def wait_for_previous_lookups(self, server_names):
         """Waits for any previous key lookups for the given servers to finish.
 
         Args:
-            server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
-                resolved once we've finished looking up keys for that server.
-                The Deferreds should be regular twisted ones which call their
-                callbacks with no logcontext.
-
-        Returns: a Deferred which resolves once all key lookups for the given
-            servers have completed. Follows the synapse rules of logcontext
-            preservation.
+            server_names (Iterable[str]): list of servers which we want to look up
+
+        Returns:
+            Deferred[None]: resolves once all key lookups for the given servers have
+                completed. Follows the synapse rules of logcontext preservation.
         """
         loop_count = 1
         while True:
             wait_on = [
                 (server_name, self.key_downloads[server_name])
-                for server_name in server_to_deferred.keys()
+                for server_name in server_names
                 if server_name in self.key_downloads
             ]
             if not wait_on:
@@ -314,19 +317,6 @@ class Keyring(object):
 
             loop_count += 1
 
-        ctx = LoggingContext.current_context()
-
-        def rm(r, server_name_):
-            with PreserveLoggingContext(ctx):
-                logger.debug("Releasing key lookup lock on %s", server_name_)
-                self.key_downloads.pop(server_name_, None)
-            return r
-
-        for server_name, deferred in server_to_deferred.items():
-            logger.debug("Got key lookup lock on %s", server_name)
-            self.key_downloads[server_name] = deferred
-            deferred.addBoth(rm, server_name)
-
     def _get_server_verify_keys(self, verify_requests):
         """Tries to find at least one key for each verify request
 
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 795703967d..8d94a503d6 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -86,35 +86,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
             getattr(LoggingContext.current_context(), "request", None), expected
         )
 
-    def test_wait_for_previous_lookups(self):
-        kr = keyring.Keyring(self.hs)
-
-        lookup_1_deferred = defer.Deferred()
-        lookup_2_deferred = defer.Deferred()
-
-        # we run the lookup in a logcontext so that the patched inlineCallbacks can check
-        # it is doing the right thing with logcontexts.
-        wait_1_deferred = run_in_context(
-            kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
-        )
-
-        # there were no previous lookups, so the deferred should be ready
-        self.successResultOf(wait_1_deferred)
-
-        # set off another wait. It should block because the first lookup
-        # hasn't yet completed.
-        wait_2_deferred = run_in_context(
-            kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
-        )
-
-        self.assertFalse(wait_2_deferred.called)
-
-        # let the first lookup complete (in the sentinel context)
-        lookup_1_deferred.callback(None)
-
-        # now the second wait should complete.
-        self.successResultOf(wait_2_deferred)
-
     def test_verify_json_objects_for_server_awaits_previous_requests(self):
         key1 = signedjson.key.generate_signing_key(1)