summary refs log tree commit diff
path: root/tests/crypto/test_keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/crypto/test_keyring.py')
-rw-r--r--tests/crypto/test_keyring.py120
1 files changed, 56 insertions, 64 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 2e6e7abf1f..5cf408f21f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -23,6 +23,7 @@ from nacl.signing import SigningKey
 from signedjson.key import encode_verify_key_base64, get_verify_key
 
 from twisted.internet import defer
+from twisted.internet.defer import Deferred, ensureDeferred
 
 from synapse.api.errors import SynapseError
 from synapse.crypto import keyring
@@ -33,7 +34,6 @@ from synapse.crypto.keyring import (
 )
 from synapse.logging.context import (
     LoggingContext,
-    PreserveLoggingContext,
     current_context,
     make_deferred_yieldable,
 )
@@ -68,54 +68,40 @@ class MockPerspectiveServer:
 
 
 class KeyringTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
-        self.mock_perspective_server = MockPerspectiveServer()
-        self.http_client = Mock()
-
-        config = self.default_config()
-        config["trusted_key_servers"] = [
-            {
-                "server_name": self.mock_perspective_server.server_name,
-                "verify_keys": self.mock_perspective_server.get_verify_keys(),
-            }
-        ]
-
-        return self.setup_test_homeserver(
-            handlers=None, http_client=self.http_client, config=config
-        )
-
-    def check_context(self, _, expected):
+    def check_context(self, val, expected):
         self.assertEquals(getattr(current_context(), "request", None), expected)
+        return val
 
     def test_verify_json_objects_for_server_awaits_previous_requests(self):
-        key1 = signedjson.key.generate_signing_key(1)
+        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher.get_keys = Mock()
+        kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
 
-        kr = keyring.Keyring(self.hs)
+        # a signed object that we are going to try to validate
+        key1 = signedjson.key.generate_signing_key(1)
         json1 = {}
         signedjson.sign.sign_json(json1, "server10", key1)
 
-        persp_resp = {
-            "server_keys": [
-                self.mock_perspective_server.get_signed_key(
-                    "server10", signedjson.key.get_verify_key(key1)
-                )
-            ]
-        }
-        persp_deferred = defer.Deferred()
+        # start off a first set of lookups. We make the mock fetcher block until this
+        # deferred completes.
+        first_lookup_deferred = Deferred()
+
+        async def first_lookup_fetch(keys_to_fetch):
+            self.assertEquals(current_context().request, "context_11")
+            self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
 
-        async def get_perspectives(**kwargs):
-            self.assertEquals(current_context().request, "11")
-            with PreserveLoggingContext():
-                await persp_deferred
-            return persp_resp
+            await make_deferred_yieldable(first_lookup_deferred)
+            return {
+                "server10": {
+                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+                }
+            }
 
-        self.http_client.post_json.side_effect = get_perspectives
+        mock_fetcher.get_keys.side_effect = first_lookup_fetch
 
-        # start off a first set of lookups
-        @defer.inlineCallbacks
-        def first_lookup():
-            with LoggingContext("11") as context_11:
-                context_11.request = "11"
+        async def first_lookup():
+            with LoggingContext("context_11") as context_11:
+                context_11.request = "context_11"
 
                 res_deferreds = kr.verify_json_objects_for_server(
                     [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
@@ -124,7 +110,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 # the unsigned json should be rejected pretty quickly
                 self.assertTrue(res_deferreds[1].called)
                 try:
-                    yield res_deferreds[1]
+                    await res_deferreds[1]
                     self.assertFalse("unsigned json didn't cause a failure")
                 except SynapseError:
                     pass
@@ -132,45 +118,51 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 self.assertFalse(res_deferreds[0].called)
                 res_deferreds[0].addBoth(self.check_context, None)
 
-                yield make_deferred_yieldable(res_deferreds[0])
+                await make_deferred_yieldable(res_deferreds[0])
 
-                # let verify_json_objects_for_server finish its work before we kill the
-                # logcontext
-                yield self.clock.sleep(0)
+        d0 = ensureDeferred(first_lookup())
 
-        d0 = first_lookup()
-
-        # wait a tick for it to send the request to the perspectives server
-        # (it first tries the datastore)
-        self.pump()
-        self.http_client.post_json.assert_called_once()
+        mock_fetcher.get_keys.assert_called_once()
 
         # a second request for a server with outstanding requests
         # should block rather than start a second call
-        @defer.inlineCallbacks
-        def second_lookup():
-            with LoggingContext("12") as context_12:
-                context_12.request = "12"
-                self.http_client.post_json.reset_mock()
-                self.http_client.post_json.return_value = defer.Deferred()
+
+        async def second_lookup_fetch(keys_to_fetch):
+            self.assertEquals(current_context().request, "context_12")
+            return {
+                "server10": {
+                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+                }
+            }
+
+        mock_fetcher.get_keys.reset_mock()
+        mock_fetcher.get_keys.side_effect = second_lookup_fetch
+        second_lookup_state = [0]
+
+        async def second_lookup():
+            with LoggingContext("context_12") as context_12:
+                context_12.request = "context_12"
 
                 res_deferreds_2 = kr.verify_json_objects_for_server(
                     [("server10", json1, 0, "test")]
                 )
                 res_deferreds_2[0].addBoth(self.check_context, None)
-                yield make_deferred_yieldable(res_deferreds_2[0])
+                second_lookup_state[0] = 1
+                await make_deferred_yieldable(res_deferreds_2[0])
+                second_lookup_state[0] = 2
 
-                # let verify_json_objects_for_server finish its work before we kill the
-                # logcontext
-                yield self.clock.sleep(0)
-
-        d2 = second_lookup()
+        d2 = ensureDeferred(second_lookup())
 
         self.pump()
-        self.http_client.post_json.assert_not_called()
+        # the second request should be pending, but the fetcher should not yet have been
+        # called
+        self.assertEqual(second_lookup_state[0], 1)
+        mock_fetcher.get_keys.assert_not_called()
 
         # complete the first request
-        persp_deferred.callback(persp_resp)
+        first_lookup_deferred.callback(None)
+
+        # and now both verifications should succeed.
         self.get_success(d0)
         self.get_success(d2)