summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8398.bugfix1
-rw-r--r--synapse/crypto/keyring.py70
-rw-r--r--tests/crypto/test_keyring.py120
3 files changed, 101 insertions, 90 deletions
diff --git a/changelog.d/8398.bugfix b/changelog.d/8398.bugfix
new file mode 100644
index 0000000000..e432aeebf1
--- /dev/null
+++ b/changelog.d/8398.bugfix
@@ -0,0 +1 @@
+Fix "Re-starting finished log context" warning when receiving an event we already had over federation.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 42e4087a92..c04ad77cf9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -42,7 +42,6 @@ from synapse.api.errors import (
 )
 from synapse.logging.context import (
     PreserveLoggingContext,
-    current_context,
     make_deferred_yieldable,
     preserve_fn,
     run_in_background,
@@ -233,8 +232,6 @@ class Keyring:
         """
 
         try:
-            ctx = current_context()
-
             # map from server name to a set of outstanding request ids
             server_to_request_ids = {}
 
@@ -265,12 +262,8 @@ class Keyring:
 
                 # if there are no more requests for this server, we can drop the lock.
                 if not server_requests:
-                    with PreserveLoggingContext(ctx):
-                        logger.debug("Releasing key lookup lock on %s", server_name)
-
-                    # ... but not immediately, as that can cause stack explosions if
-                    # we get a long queue of lookups.
-                    self.clock.call_later(0, drop_server_lock, server_name)
+                    logger.debug("Releasing key lookup lock on %s", server_name)
+                    drop_server_lock(server_name)
 
                 return res
 
@@ -335,20 +328,32 @@ class Keyring:
                         )
 
                     # look for any requests which weren't satisfied
-                    with PreserveLoggingContext():
-                        for verify_request in remaining_requests:
-                            verify_request.key_ready.errback(
-                                SynapseError(
-                                    401,
-                                    "No key for %s with ids in %s (min_validity %i)"
-                                    % (
-                                        verify_request.server_name,
-                                        verify_request.key_ids,
-                                        verify_request.minimum_valid_until_ts,
-                                    ),
-                                    Codes.UNAUTHORIZED,
-                                )
+                    while remaining_requests:
+                        verify_request = remaining_requests.pop()
+                        rq_str = (
+                            "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
+                            % (
+                                verify_request.server_name,
+                                verify_request.key_ids,
+                                verify_request.minimum_valid_until_ts,
                             )
+                        )
+
+                        # If we run the errback immediately, it may cancel our
+                        # loggingcontext while we are still in it, so instead we
+                        # schedule it for the next time round the reactor.
+                        #
+                        # (this also ensures that we don't get a stack overflow if we
+                        # has a massive queue of lookups waiting for this server).
+                        self.clock.call_later(
+                            0,
+                            verify_request.key_ready.errback,
+                            SynapseError(
+                                401,
+                                "Failed to find any key to satisfy %s" % (rq_str,),
+                                Codes.UNAUTHORIZED,
+                            ),
+                        )
             except Exception as err:
                 # we don't really expect to get here, because any errors should already
                 # have been caught and logged. But if we do, let's log the error and make
@@ -410,10 +415,23 @@ class Keyring:
                     # key was not valid at this point
                     continue
 
-                with PreserveLoggingContext():
-                    verify_request.key_ready.callback(
-                        (server_name, key_id, fetch_key_result.verify_key)
-                    )
+                # we have a valid key for this request. If we run the callback
+                # immediately, it may cancel our loggingcontext while we are still in
+                # it, so instead we schedule it for the next time round the reactor.
+                #
+                # (this also ensures that we don't get a stack overflow if we had
+                # a massive queue of lookups waiting for this server).
+                logger.debug(
+                    "Found key %s:%s for %s",
+                    server_name,
+                    key_id,
+                    verify_request.request_name,
+                )
+                self.clock.call_later(
+                    0,
+                    verify_request.key_ready.callback,
+                    (server_name, key_id, fetch_key_result.verify_key),
+                )
                 completed.append(verify_request)
                 break
 
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 2e6e7abf1f..5cf408f21f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -23,6 +23,7 @@ from nacl.signing import SigningKey
 from signedjson.key import encode_verify_key_base64, get_verify_key
 
 from twisted.internet import defer
+from twisted.internet.defer import Deferred, ensureDeferred
 
 from synapse.api.errors import SynapseError
 from synapse.crypto import keyring
@@ -33,7 +34,6 @@ from synapse.crypto.keyring import (
 )
 from synapse.logging.context import (
     LoggingContext,
-    PreserveLoggingContext,
     current_context,
     make_deferred_yieldable,
 )
@@ -68,54 +68,40 @@ class MockPerspectiveServer:
 
 
 class KeyringTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
-        self.mock_perspective_server = MockPerspectiveServer()
-        self.http_client = Mock()
-
-        config = self.default_config()
-        config["trusted_key_servers"] = [
-            {
-                "server_name": self.mock_perspective_server.server_name,
-                "verify_keys": self.mock_perspective_server.get_verify_keys(),
-            }
-        ]
-
-        return self.setup_test_homeserver(
-            handlers=None, http_client=self.http_client, config=config
-        )
-
-    def check_context(self, _, expected):
+    def check_context(self, val, expected):
         self.assertEquals(getattr(current_context(), "request", None), expected)
+        return val
 
     def test_verify_json_objects_for_server_awaits_previous_requests(self):
-        key1 = signedjson.key.generate_signing_key(1)
+        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher.get_keys = Mock()
+        kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
 
-        kr = keyring.Keyring(self.hs)
+        # a signed object that we are going to try to validate
+        key1 = signedjson.key.generate_signing_key(1)
         json1 = {}
         signedjson.sign.sign_json(json1, "server10", key1)
 
-        persp_resp = {
-            "server_keys": [
-                self.mock_perspective_server.get_signed_key(
-                    "server10", signedjson.key.get_verify_key(key1)
-                )
-            ]
-        }
-        persp_deferred = defer.Deferred()
+        # start off a first set of lookups. We make the mock fetcher block until this
+        # deferred completes.
+        first_lookup_deferred = Deferred()
+
+        async def first_lookup_fetch(keys_to_fetch):
+            self.assertEquals(current_context().request, "context_11")
+            self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
 
-        async def get_perspectives(**kwargs):
-            self.assertEquals(current_context().request, "11")
-            with PreserveLoggingContext():
-                await persp_deferred
-            return persp_resp
+            await make_deferred_yieldable(first_lookup_deferred)
+            return {
+                "server10": {
+                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+                }
+            }
 
-        self.http_client.post_json.side_effect = get_perspectives
+        mock_fetcher.get_keys.side_effect = first_lookup_fetch
 
-        # start off a first set of lookups
-        @defer.inlineCallbacks
-        def first_lookup():
-            with LoggingContext("11") as context_11:
-                context_11.request = "11"
+        async def first_lookup():
+            with LoggingContext("context_11") as context_11:
+                context_11.request = "context_11"
 
                 res_deferreds = kr.verify_json_objects_for_server(
                     [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
@@ -124,7 +110,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 # the unsigned json should be rejected pretty quickly
                 self.assertTrue(res_deferreds[1].called)
                 try:
-                    yield res_deferreds[1]
+                    await res_deferreds[1]
                     self.assertFalse("unsigned json didn't cause a failure")
                 except SynapseError:
                     pass
@@ -132,45 +118,51 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 self.assertFalse(res_deferreds[0].called)
                 res_deferreds[0].addBoth(self.check_context, None)
 
-                yield make_deferred_yieldable(res_deferreds[0])
+                await make_deferred_yieldable(res_deferreds[0])
 
-                # let verify_json_objects_for_server finish its work before we kill the
-                # logcontext
-                yield self.clock.sleep(0)
+        d0 = ensureDeferred(first_lookup())
 
-        d0 = first_lookup()
-
-        # wait a tick for it to send the request to the perspectives server
-        # (it first tries the datastore)
-        self.pump()
-        self.http_client.post_json.assert_called_once()
+        mock_fetcher.get_keys.assert_called_once()
 
         # a second request for a server with outstanding requests
         # should block rather than start a second call
-        @defer.inlineCallbacks
-        def second_lookup():
-            with LoggingContext("12") as context_12:
-                context_12.request = "12"
-                self.http_client.post_json.reset_mock()
-                self.http_client.post_json.return_value = defer.Deferred()
+
+        async def second_lookup_fetch(keys_to_fetch):
+            self.assertEquals(current_context().request, "context_12")
+            return {
+                "server10": {
+                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+                }
+            }
+
+        mock_fetcher.get_keys.reset_mock()
+        mock_fetcher.get_keys.side_effect = second_lookup_fetch
+        second_lookup_state = [0]
+
+        async def second_lookup():
+            with LoggingContext("context_12") as context_12:
+                context_12.request = "context_12"
 
                 res_deferreds_2 = kr.verify_json_objects_for_server(
                     [("server10", json1, 0, "test")]
                 )
                 res_deferreds_2[0].addBoth(self.check_context, None)
-                yield make_deferred_yieldable(res_deferreds_2[0])
+                second_lookup_state[0] = 1
+                await make_deferred_yieldable(res_deferreds_2[0])
+                second_lookup_state[0] = 2
 
-                # let verify_json_objects_for_server finish its work before we kill the
-                # logcontext
-                yield self.clock.sleep(0)
-
-        d2 = second_lookup()
+        d2 = ensureDeferred(second_lookup())
 
         self.pump()
-        self.http_client.post_json.assert_not_called()
+        # the second request should be pending, but the fetcher should not yet have been
+        # called
+        self.assertEqual(second_lookup_state[0], 1)
+        mock_fetcher.get_keys.assert_not_called()
 
         # complete the first request
-        persp_deferred.callback(persp_resp)
+        first_lookup_deferred.callback(None)
+
+        # and now both verifications should succeed.
         self.get_success(d0)
         self.get_success(d2)