summary refs log tree commit diff
path: root/tests/crypto/test_keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/crypto/test_keyring.py')
-rw-r--r--tests/crypto/test_keyring.py150
1 files changed, 100 insertions, 50 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index c4f0bbd3dd..70c8e72303 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,6 +19,7 @@ from mock import Mock
 import canonicaljson
 import signedjson.key
 import signedjson.sign
+from nacl.signing import SigningKey
 from signedjson.key import encode_verify_key_base64, get_verify_key
 
 from twisted.internet import defer
@@ -33,6 +34,7 @@ from synapse.crypto.keyring import (
 from synapse.logging.context import (
     LoggingContext,
     PreserveLoggingContext,
+    current_context,
     make_deferred_yieldable,
 )
 from synapse.storage.keys import FetchKeyResult
@@ -82,9 +84,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         )
 
     def check_context(self, _, expected):
-        self.assertEquals(
-            getattr(LoggingContext.current_context(), "request", None), expected
-        )
+        self.assertEquals(getattr(current_context(), "request", None), expected)
 
     def test_verify_json_objects_for_server_awaits_previous_requests(self):
         key1 = signedjson.key.generate_signing_key(1)
@@ -104,7 +104,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 
         @defer.inlineCallbacks
         def get_perspectives(**kwargs):
-            self.assertEquals(LoggingContext.current_context().request, "11")
+            self.assertEquals(current_context().request, "11")
             with PreserveLoggingContext():
                 yield persp_deferred
             return persp_resp
@@ -178,7 +178,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         kr = keyring.Keyring(self.hs)
 
         key1 = signedjson.key.generate_signing_key(1)
-        r = self.hs.datastore.store_server_verify_keys(
+        r = self.hs.get_datastore().store_server_verify_keys(
             "server9",
             time.time() * 1000,
             [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
@@ -209,7 +209,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         )
 
         key1 = signedjson.key.generate_signing_key(1)
-        r = self.hs.datastore.store_server_verify_keys(
+        r = self.hs.get_datastore().store_server_verify_keys(
             "server9",
             time.time() * 1000,
             [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
@@ -412,34 +412,37 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
             handlers=None, http_client=self.http_client, config=config
         )
 
-    def test_get_keys_from_perspectives(self):
-        # arbitrarily advance the clock a bit
-        self.reactor.advance(100)
-
-        fetcher = PerspectivesKeyFetcher(self.hs)
-
-        SERVER_NAME = "server2"
-        testkey = signedjson.key.generate_signing_key("ver1")
-        testverifykey = signedjson.key.get_verify_key(testkey)
-        testverifykey_id = "ed25519:ver1"
-        VALID_UNTIL_TS = 200 * 1000
+    def build_perspectives_response(
+        self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
+    ) -> dict:
+        """
+        Build a valid perspectives server response to a request for the given key
+        """
+        verify_key = signedjson.key.get_verify_key(signing_key)
+        verifykey_id = "%s:%s" % (verify_key.alg, verify_key.version)
 
-        # valid response
         response = {
-            "server_name": SERVER_NAME,
+            "server_name": server_name,
             "old_verify_keys": {},
-            "valid_until_ts": VALID_UNTIL_TS,
+            "valid_until_ts": valid_until_ts,
             "verify_keys": {
-                testverifykey_id: {
-                    "key": signedjson.key.encode_verify_key_base64(testverifykey)
+                verifykey_id: {
+                    "key": signedjson.key.encode_verify_key_base64(verify_key)
                 }
             },
         }
-
         # the response must be signed by both the origin server and the perspectives
         # server.
-        signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+        signedjson.sign.sign_json(response, server_name, signing_key)
         self.mock_perspective_server.sign_response(response)
+        return response
+
+    def expect_outgoing_key_query(
+        self, expected_server_name: str, expected_key_id: str, response: dict
+    ) -> None:
+        """
+        Tell the mock http client to expect a perspectives-server key query
+        """
 
         def post_json(destination, path, data, **kwargs):
             self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -447,11 +450,79 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
             # check that the request is for the expected key
             q = data["server_keys"]
-            self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
+            self.assertEqual(list(q[expected_server_name].keys()), [expected_key_id])
             return {"server_keys": [response]}
 
         self.http_client.post_json.side_effect = post_json
 
+    def test_get_keys_from_perspectives(self):
+        # arbitrarily advance the clock a bit
+        self.reactor.advance(100)
+
+        fetcher = PerspectivesKeyFetcher(self.hs)
+
+        SERVER_NAME = "server2"
+        testkey = signedjson.key.generate_signing_key("ver1")
+        testverifykey = signedjson.key.get_verify_key(testkey)
+        testverifykey_id = "ed25519:ver1"
+        VALID_UNTIL_TS = 200 * 1000
+
+        response = self.build_perspectives_response(
+            SERVER_NAME, testkey, VALID_UNTIL_TS,
+        )
+
+        self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
+
+        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+        self.assertIn(SERVER_NAME, keys)
+        k = keys[SERVER_NAME][testverifykey_id]
+        self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+        self.assertEqual(k.verify_key, testverifykey)
+        self.assertEqual(k.verify_key.alg, "ed25519")
+        self.assertEqual(k.verify_key.version, "ver1")
+
+        # check that the perspectives store is correctly updated
+        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
+        key_json = self.get_success(
+            self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+        )
+        res = key_json[lookup_triplet]
+        self.assertEqual(len(res), 1)
+        res = res[0]
+        self.assertEqual(res["key_id"], testverifykey_id)
+        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
+        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
+        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+
+        self.assertEqual(
+            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
+        )
+
+    def test_get_perspectives_own_key(self):
+        """Check that we can get the perspectives server's own keys
+
+        This is slightly complicated by the fact that the perspectives server may
+        use different keys for signing notary responses.
+        """
+
+        # arbitrarily advance the clock a bit
+        self.reactor.advance(100)
+
+        fetcher = PerspectivesKeyFetcher(self.hs)
+
+        SERVER_NAME = self.mock_perspective_server.server_name
+        testkey = signedjson.key.generate_signing_key("ver1")
+        testverifykey = signedjson.key.get_verify_key(testkey)
+        testverifykey_id = "ed25519:ver1"
+        VALID_UNTIL_TS = 200 * 1000
+
+        response = self.build_perspectives_response(
+            SERVER_NAME, testkey, VALID_UNTIL_TS
+        )
+
+        self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
+
         keys_to_fetch = {SERVER_NAME: {"key1": 0}}
         keys = self.get_success(fetcher.get_keys(keys_to_fetch))
         self.assertIn(SERVER_NAME, keys)
@@ -490,35 +561,14 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         VALID_UNTIL_TS = 200 * 1000
 
         def build_response():
-            # valid response
-            response = {
-                "server_name": SERVER_NAME,
-                "old_verify_keys": {},
-                "valid_until_ts": VALID_UNTIL_TS,
-                "verify_keys": {
-                    testverifykey_id: {
-                        "key": signedjson.key.encode_verify_key_base64(testverifykey)
-                    }
-                },
-            }
-
-            # the response must be signed by both the origin server and the perspectives
-            # server.
-            signedjson.sign.sign_json(response, SERVER_NAME, testkey)
-            self.mock_perspective_server.sign_response(response)
-            return response
+            return self.build_perspectives_response(
+                SERVER_NAME, testkey, VALID_UNTIL_TS
+            )
 
         def get_key_from_perspectives(response):
             fetcher = PerspectivesKeyFetcher(self.hs)
             keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-
-            def post_json(destination, path, data, **kwargs):
-                self.assertEqual(destination, self.mock_perspective_server.server_name)
-                self.assertEqual(path, "/_matrix/key/v2/query")
-                return {"server_keys": [response]}
-
-            self.http_client.post_json.side_effect = post_json
-
+            self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
             return self.get_success(fetcher.get_keys(keys_to_fetch))
 
         # start with a valid response so we can check we are testing the right thing