summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/5244.misc1
-rw-r--r--synapse/crypto/keyring.py315
-rw-r--r--tests/crypto/test_keyring.py34
3 files changed, 204 insertions, 146 deletions
diff --git a/changelog.d/5244.misc b/changelog.d/5244.misc
new file mode 100644
index 0000000000..9cc1fb869d
--- /dev/null
+++ b/changelog.d/5244.misc
@@ -0,0 +1 @@
+Refactor synapse.crypto.keyring to use a KeyFetcher interface.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 14a27288fd..eaf41b983c 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -80,12 +80,13 @@ class KeyLookupError(ValueError):
 
 class Keyring(object):
     def __init__(self, hs):
-        self.store = hs.get_datastore()
         self.clock = hs.get_clock()
-        self.client = hs.get_http_client()
-        self.config = hs.get_config()
-        self.perspective_servers = self.config.perspectives
-        self.hs = hs
+
+        self._key_fetchers = (
+            StoreKeyFetcher(hs),
+            PerspectivesKeyFetcher(hs),
+            ServerKeyFetcher(hs),
+        )
 
         # map from server name to Deferred. Has an entry for each server with
         # an ongoing key download; the Deferred completes once the download
@@ -271,13 +272,6 @@ class Keyring(object):
             verify_requests (list[VerifyKeyRequest]): list of verify requests
         """
 
-        # These are functions that produce keys given a list of key ids
-        key_fetch_fns = (
-            self.get_keys_from_store,  # First try the local store
-            self.get_keys_from_perspectives,  # Then try via perspectives
-            self.get_keys_from_server,  # Then try directly
-        )
-
         @defer.inlineCallbacks
         def do_iterations():
             with Measure(self.clock, "get_server_verify_keys"):
@@ -288,8 +282,8 @@ class Keyring(object):
                         verify_request.key_ids
                     )
 
-                for fn in key_fetch_fns:
-                    results = yield fn(missing_keys.items())
+                for f in self._key_fetchers:
+                    results = yield f.get_keys(missing_keys.items())
 
                     # We now need to figure out which verify requests we have keys
                     # for and which we don't
@@ -348,8 +342,9 @@ class Keyring(object):
 
         run_in_background(do_iterations).addErrback(on_err)
 
-    @defer.inlineCallbacks
-    def get_keys_from_store(self, server_name_and_key_ids):
+
+class KeyFetcher(object):
+    def get_keys(self, server_name_and_key_ids):
         """
         Args:
             server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
@@ -359,6 +354,18 @@ class Keyring(object):
             Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
                 map from server_name -> key_id -> FetchKeyResult
         """
+        raise NotImplementedError
+
+
+class StoreKeyFetcher(KeyFetcher):
+    """KeyFetcher impl which fetches keys from our data store"""
+
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def get_keys(self, server_name_and_key_ids):
+        """see KeyFetcher.get_keys"""
         keys_to_fetch = (
             (server_name, key_id)
             for server_name, key_ids in server_name_and_key_ids
@@ -370,8 +377,127 @@ class Keyring(object):
             keys.setdefault(server_name, {})[key_id] = key
         defer.returnValue(keys)
 
+
+class BaseV2KeyFetcher(object):
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+        self.config = hs.get_config()
+
+    @defer.inlineCallbacks
+    def process_v2_response(
+        self, from_server, response_json, time_added_ms, requested_ids=[]
+    ):
+        """Parse a 'Server Keys' structure from the result of a /key request
+
+        This is used to parse either the entirety of the response from
+        GET /_matrix/key/v2/server, or a single entry from the list returned by
+        POST /_matrix/key/v2/query.
+
+        Checks that each signature in the response that claims to come from the origin
+        server is valid. (Does not check that there actually is such a signature, for
+        some reason.)
+
+        Stores the json in server_keys_json so that it can be used for future responses
+        to /_matrix/key/v2/query.
+
+        Args:
+            from_server (str): the name of the server producing this result: either
+                the origin server for a /_matrix/key/v2/server request, or the notary
+                for a /_matrix/key/v2/query.
+
+            response_json (dict): the json-decoded Server Keys response object
+
+            time_added_ms (int): the timestamp to record in server_keys_json
+
+            requested_ids (iterable[str]): a list of the key IDs that were requested.
+                We will store the json for these key ids as well as any that are
+                actually in the response
+
+        Returns:
+            Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+        """
+        ts_valid_until_ms = response_json[u"valid_until_ts"]
+
+        # start by extracting the keys from the response, since they may be required
+        # to validate the signature on the response.
+        verify_keys = {}
+        for key_id, key_data in response_json["verify_keys"].items():
+            if is_signing_algorithm_supported(key_id):
+                key_base64 = key_data["key"]
+                key_bytes = decode_base64(key_base64)
+                verify_key = decode_verify_key_bytes(key_id, key_bytes)
+                verify_keys[key_id] = FetchKeyResult(
+                    verify_key=verify_key, valid_until_ts=ts_valid_until_ms
+                )
+
+        # TODO: improve this signature checking
+        server_name = response_json["server_name"]
+        for key_id in response_json["signatures"].get(server_name, {}):
+            if key_id not in verify_keys:
+                raise KeyLookupError(
+                    "Key response must include verification keys for all signatures"
+                )
+
+            verify_signed_json(
+                response_json, server_name, verify_keys[key_id].verify_key
+            )
+
+        for key_id, key_data in response_json["old_verify_keys"].items():
+            if is_signing_algorithm_supported(key_id):
+                key_base64 = key_data["key"]
+                key_bytes = decode_base64(key_base64)
+                verify_key = decode_verify_key_bytes(key_id, key_bytes)
+                verify_keys[key_id] = FetchKeyResult(
+                    verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
+                )
+
+        # re-sign the json with our own key, so that it is ready if we are asked to
+        # give it out as a notary server
+        signed_key_json = sign_json(
+            response_json, self.config.server_name, self.config.signing_key[0]
+        )
+
+        signed_key_json_bytes = encode_canonical_json(signed_key_json)
+
+        # for reasons I don't quite understand, we store this json for the key ids we
+        # requested, as well as those we got.
+        updated_key_ids = set(requested_ids)
+        updated_key_ids.update(verify_keys)
+
+        yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self.store.store_server_keys_json,
+                        server_name=server_name,
+                        key_id=key_id,
+                        from_server=from_server,
+                        ts_now_ms=time_added_ms,
+                        ts_expires_ms=ts_valid_until_ms,
+                        key_json_bytes=signed_key_json_bytes,
+                    )
+                    for key_id in updated_key_ids
+                ],
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
+        )
+
+        defer.returnValue(verify_keys)
+
+
+class PerspectivesKeyFetcher(BaseV2KeyFetcher):
+    """KeyFetcher impl which fetches keys from the "perspectives" servers"""
+
+    def __init__(self, hs):
+        super(PerspectivesKeyFetcher, self).__init__(hs)
+        self.clock = hs.get_clock()
+        self.client = hs.get_http_client()
+        self.perspective_servers = self.config.perspectives
+
     @defer.inlineCallbacks
-    def get_keys_from_perspectives(self, server_name_and_key_ids):
+    def get_keys(self, server_name_and_key_ids):
+        """see KeyFetcher.get_keys"""
+
         @defer.inlineCallbacks
         def get_key(perspective_name, perspective_keys):
             try:
@@ -409,28 +535,6 @@ class Keyring(object):
         defer.returnValue(union_of_keys)
 
     @defer.inlineCallbacks
-    def get_keys_from_server(self, server_name_and_key_ids):
-        results = yield logcontext.make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self.get_server_verify_key_v2_direct, server_name, key_ids
-                    )
-                    for server_name, key_ids in server_name_and_key_ids
-                ],
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError)
-        )
-
-        merged = {}
-        for result in results:
-            merged.update(result)
-
-        defer.returnValue(
-            {server_name: keys for server_name, keys in merged.items() if keys}
-        )
-
-    @defer.inlineCallbacks
     def get_server_verify_key_v2_indirect(
         self, server_names_and_key_ids, perspective_name, perspective_keys
     ):
@@ -520,6 +624,38 @@ class Keyring(object):
 
         defer.returnValue(keys)
 
+
+class ServerKeyFetcher(BaseV2KeyFetcher):
+    """KeyFetcher impl which fetches keys from the origin servers"""
+
+    def __init__(self, hs):
+        super(ServerKeyFetcher, self).__init__(hs)
+        self.clock = hs.get_clock()
+        self.client = hs.get_http_client()
+
+    @defer.inlineCallbacks
+    def get_keys(self, server_name_and_key_ids):
+        """see KeyFetcher.get_keys"""
+        results = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self.get_server_verify_key_v2_direct, server_name, key_ids
+                    )
+                    for server_name, key_ids in server_name_and_key_ids
+                ],
+                consumeErrors=True,
+            ).addErrback(unwrapFirstError)
+        )
+
+        merged = {}
+        for result in results:
+            merged.update(result)
+
+        defer.returnValue(
+            {server_name: keys for server_name, keys in merged.items() if keys}
+        )
+
     @defer.inlineCallbacks
     def get_server_verify_key_v2_direct(self, server_name, key_ids):
         keys = {}  # type: dict[str, FetchKeyResult]
@@ -568,107 +704,6 @@ class Keyring(object):
 
         defer.returnValue({server_name: keys})
 
-    @defer.inlineCallbacks
-    def process_v2_response(
-        self, from_server, response_json, time_added_ms, requested_ids=[]
-    ):
-        """Parse a 'Server Keys' structure from the result of a /key request
-
-        This is used to parse either the entirety of the response from
-        GET /_matrix/key/v2/server, or a single entry from the list returned by
-        POST /_matrix/key/v2/query.
-
-        Checks that each signature in the response that claims to come from the origin
-        server is valid. (Does not check that there actually is such a signature, for
-        some reason.)
-
-        Stores the json in server_keys_json so that it can be used for future responses
-        to /_matrix/key/v2/query.
-
-        Args:
-            from_server (str): the name of the server producing this result: either
-                the origin server for a /_matrix/key/v2/server request, or the notary
-                for a /_matrix/key/v2/query.
-
-            response_json (dict): the json-decoded Server Keys response object
-
-            time_added_ms (int): the timestamp to record in server_keys_json
-
-            requested_ids (iterable[str]): a list of the key IDs that were requested.
-                We will store the json for these key ids as well as any that are
-                actually in the response
-
-        Returns:
-            Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
-        """
-        ts_valid_until_ms = response_json[u"valid_until_ts"]
-
-        # start by extracting the keys from the response, since they may be required
-        # to validate the signature on the response.
-        verify_keys = {}
-        for key_id, key_data in response_json["verify_keys"].items():
-            if is_signing_algorithm_supported(key_id):
-                key_base64 = key_data["key"]
-                key_bytes = decode_base64(key_base64)
-                verify_key = decode_verify_key_bytes(key_id, key_bytes)
-                verify_keys[key_id] = FetchKeyResult(
-                    verify_key=verify_key, valid_until_ts=ts_valid_until_ms
-                )
-
-        # TODO: improve this signature checking
-        server_name = response_json["server_name"]
-        for key_id in response_json["signatures"].get(server_name, {}):
-            if key_id not in verify_keys:
-                raise KeyLookupError(
-                    "Key response must include verification keys for all signatures"
-                )
-
-            verify_signed_json(
-                response_json, server_name, verify_keys[key_id].verify_key
-            )
-
-        for key_id, key_data in response_json["old_verify_keys"].items():
-            if is_signing_algorithm_supported(key_id):
-                key_base64 = key_data["key"]
-                key_bytes = decode_base64(key_base64)
-                verify_key = decode_verify_key_bytes(key_id, key_bytes)
-                verify_keys[key_id] = FetchKeyResult(
-                    verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
-                )
-
-        # re-sign the json with our own key, so that it is ready if we are asked to
-        # give it out as a notary server
-        signed_key_json = sign_json(
-            response_json, self.config.server_name, self.config.signing_key[0]
-        )
-
-        signed_key_json_bytes = encode_canonical_json(signed_key_json)
-
-        # for reasons I don't quite understand, we store this json for the key ids we
-        # requested, as well as those we got.
-        updated_key_ids = set(requested_ids)
-        updated_key_ids.update(verify_keys)
-
-        yield logcontext.make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self.store.store_server_keys_json,
-                        server_name=server_name,
-                        key_id=key_id,
-                        from_server=from_server,
-                        ts_now_ms=time_added_ms,
-                        ts_expires_ms=ts_valid_until_ms,
-                        key_json_bytes=signed_key_json_bytes,
-                    )
-                    for key_id in updated_key_ids
-                ],
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError)
-        )
-
-        defer.returnValue(verify_keys)
-
 
 @defer.inlineCallbacks
 def _handle_key_deferred(verify_request):
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 83de32b05d..de61bad15d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -24,7 +24,11 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.crypto import keyring
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import (
+    KeyLookupError,
+    PerspectivesKeyFetcher,
+    ServerKeyFetcher,
+)
 from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext
 from synapse.util.logcontext import LoggingContext
@@ -218,12 +222,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         self.assertFalse(d.called)
         self.get_success(d)
 
+
+class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        self.http_client = Mock()
+        hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+        return hs
+
     def test_get_keys_from_server(self):
         # arbitrarily advance the clock a bit
         self.reactor.advance(100)
 
         SERVER_NAME = "server2"
-        kr = keyring.Keyring(self.hs)
+        fetcher = ServerKeyFetcher(self.hs)
         testkey = signedjson.key.generate_signing_key("ver1")
         testverifykey = signedjson.key.get_verify_key(testkey)
         testverifykey_id = "ed25519:ver1"
@@ -250,7 +261,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         self.http_client.get_json.side_effect = get_json
 
         server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
-        keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
+        keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
         k = keys[SERVER_NAME][testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
@@ -278,15 +289,26 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # change the server name: it should cause a rejection
         response["server_name"] = "OTHER_SERVER"
         self.get_failure(
-            kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
+            fetcher.get_keys(server_name_and_key_ids), KeyLookupError
         )
 
+
+class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        self.mock_perspective_server = MockPerspectiveServer()
+        self.http_client = Mock()
+        hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+        keys = self.mock_perspective_server.get_verify_keys()
+        hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
+        return hs
+
     def test_get_keys_from_perspectives(self):
         # arbitrarily advance the clock a bit
         self.reactor.advance(100)
 
+        fetcher = PerspectivesKeyFetcher(self.hs)
+
         SERVER_NAME = "server2"
-        kr = keyring.Keyring(self.hs)
         testkey = signedjson.key.generate_signing_key("ver1")
         testverifykey = signedjson.key.get_verify_key(testkey)
         testverifykey_id = "ed25519:ver1"
@@ -320,7 +342,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         self.http_client.post_json.side_effect = post_json
 
         server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
-        keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
+        keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
         self.assertIn(SERVER_NAME, keys)
         k = keys[SERVER_NAME][testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)