summary refs log tree commit diff
path: root/tests/crypto
diff options
context:
space:
mode:
Diffstat (limited to 'tests/crypto')
-rw-r--r--tests/crypto/test_keyring.py34
1 files changed, 28 insertions, 6 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 83de32b05d..de61bad15d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -24,7 +24,11 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.crypto import keyring
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import (
+    KeyLookupError,
+    PerspectivesKeyFetcher,
+    ServerKeyFetcher,
+)
 from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext
 from synapse.util.logcontext import LoggingContext
@@ -218,12 +222,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         self.assertFalse(d.called)
         self.get_success(d)
 
+
+class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        self.http_client = Mock()
+        hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+        return hs
+
     def test_get_keys_from_server(self):
         # arbitrarily advance the clock a bit
         self.reactor.advance(100)
 
         SERVER_NAME = "server2"
-        kr = keyring.Keyring(self.hs)
+        fetcher = ServerKeyFetcher(self.hs)
         testkey = signedjson.key.generate_signing_key("ver1")
         testverifykey = signedjson.key.get_verify_key(testkey)
         testverifykey_id = "ed25519:ver1"
@@ -250,7 +261,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         self.http_client.get_json.side_effect = get_json
 
         server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
-        keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
+        keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
         k = keys[SERVER_NAME][testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
@@ -278,15 +289,26 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # change the server name: it should cause a rejection
         response["server_name"] = "OTHER_SERVER"
         self.get_failure(
-            kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
+            fetcher.get_keys(server_name_and_key_ids), KeyLookupError
         )
 
+
+class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        self.mock_perspective_server = MockPerspectiveServer()
+        self.http_client = Mock()
+        hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+        keys = self.mock_perspective_server.get_verify_keys()
+        hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
+        return hs
+
     def test_get_keys_from_perspectives(self):
         # arbitrarily advance the clock a bit
         self.reactor.advance(100)
 
+        fetcher = PerspectivesKeyFetcher(self.hs)
+
         SERVER_NAME = "server2"
-        kr = keyring.Keyring(self.hs)
         testkey = signedjson.key.generate_signing_key("ver1")
         testverifykey = signedjson.key.get_verify_key(testkey)
         testverifykey_id = "ed25519:ver1"
@@ -320,7 +342,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         self.http_client.post_json.side_effect = post_json
 
         server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
-        keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
+        keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
         self.assertIn(SERVER_NAME, keys)
         k = keys[SERVER_NAME][testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)