diff --git a/changelog.d/5244.misc b/changelog.d/5244.misc
new file mode 100644
index 0000000000..9cc1fb869d
--- /dev/null
+++ b/changelog.d/5244.misc
@@ -0,0 +1 @@
+Refactor synapse.crypto.keyring to use a KeyFetcher interface.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 14a27288fd..eaf41b983c 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -80,12 +80,13 @@ class KeyLookupError(ValueError):
class Keyring(object):
def __init__(self, hs):
- self.store = hs.get_datastore()
self.clock = hs.get_clock()
- self.client = hs.get_http_client()
- self.config = hs.get_config()
- self.perspective_servers = self.config.perspectives
- self.hs = hs
+
+ self._key_fetchers = (
+ StoreKeyFetcher(hs),
+ PerspectivesKeyFetcher(hs),
+ ServerKeyFetcher(hs),
+ )
# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
@@ -271,13 +272,6 @@ class Keyring(object):
verify_requests (list[VerifyKeyRequest]): list of verify requests
"""
- # These are functions that produce keys given a list of key ids
- key_fetch_fns = (
- self.get_keys_from_store, # First try the local store
- self.get_keys_from_perspectives, # Then try via perspectives
- self.get_keys_from_server, # Then try directly
- )
-
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
@@ -288,8 +282,8 @@ class Keyring(object):
verify_request.key_ids
)
- for fn in key_fetch_fns:
- results = yield fn(missing_keys.items())
+ for f in self._key_fetchers:
+ results = yield f.get_keys(missing_keys.items())
# We now need to figure out which verify requests we have keys
# for and which we don't
@@ -348,8 +342,9 @@ class Keyring(object):
run_in_background(do_iterations).addErrback(on_err)
- @defer.inlineCallbacks
- def get_keys_from_store(self, server_name_and_key_ids):
+
+class KeyFetcher(object):
+ def get_keys(self, server_name_and_key_ids):
"""
Args:
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
@@ -359,6 +354,18 @@ class Keyring(object):
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
map from server_name -> key_id -> FetchKeyResult
"""
+ raise NotImplementedError
+
+
+class StoreKeyFetcher(KeyFetcher):
+ """KeyFetcher impl which fetches keys from our data store"""
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def get_keys(self, server_name_and_key_ids):
+ """see KeyFetcher.get_keys"""
keys_to_fetch = (
(server_name, key_id)
for server_name, key_ids in server_name_and_key_ids
@@ -370,8 +377,127 @@ class Keyring(object):
keys.setdefault(server_name, {})[key_id] = key
defer.returnValue(keys)
+
+class BaseV2KeyFetcher(object):
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.config = hs.get_config()
+
+ @defer.inlineCallbacks
+ def process_v2_response(
+ self, from_server, response_json, time_added_ms, requested_ids=[]
+ ):
+ """Parse a 'Server Keys' structure from the result of a /key request
+
+ This is used to parse either the entirety of the response from
+ GET /_matrix/key/v2/server, or a single entry from the list returned by
+ POST /_matrix/key/v2/query.
+
+ Checks that each signature in the response that claims to come from the origin
+ server is valid. (Does not check that there actually is such a signature, for
+ some reason.)
+
+ Stores the json in server_keys_json so that it can be used for future responses
+ to /_matrix/key/v2/query.
+
+ Args:
+ from_server (str): the name of the server producing this result: either
+ the origin server for a /_matrix/key/v2/server request, or the notary
+ for a /_matrix/key/v2/query.
+
+ response_json (dict): the json-decoded Server Keys response object
+
+ time_added_ms (int): the timestamp to record in server_keys_json
+
+ requested_ids (iterable[str]): a list of the key IDs that were requested.
+ We will store the json for these key ids as well as any that are
+ actually in the response
+
+ Returns:
+ Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+ """
+ ts_valid_until_ms = response_json[u"valid_until_ts"]
+
+ # start by extracting the keys from the response, since they may be required
+ # to validate the signature on the response.
+ verify_keys = {}
+ for key_id, key_data in response_json["verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_keys[key_id] = FetchKeyResult(
+ verify_key=verify_key, valid_until_ts=ts_valid_until_ms
+ )
+
+ # TODO: improve this signature checking
+ server_name = response_json["server_name"]
+ for key_id in response_json["signatures"].get(server_name, {}):
+ if key_id not in verify_keys:
+ raise KeyLookupError(
+ "Key response must include verification keys for all signatures"
+ )
+
+ verify_signed_json(
+ response_json, server_name, verify_keys[key_id].verify_key
+ )
+
+ for key_id, key_data in response_json["old_verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_keys[key_id] = FetchKeyResult(
+ verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
+ )
+
+ # re-sign the json with our own key, so that it is ready if we are asked to
+ # give it out as a notary server
+ signed_key_json = sign_json(
+ response_json, self.config.server_name, self.config.signing_key[0]
+ )
+
+ signed_key_json_bytes = encode_canonical_json(signed_key_json)
+
+ # for reasons I don't quite understand, we store this json for the key ids we
+ # requested, as well as those we got.
+ updated_key_ids = set(requested_ids)
+ updated_key_ids.update(verify_keys)
+
+ yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self.store.store_server_keys_json,
+ server_name=server_name,
+ key_id=key_id,
+ from_server=from_server,
+ ts_now_ms=time_added_ms,
+ ts_expires_ms=ts_valid_until_ms,
+ key_json_bytes=signed_key_json_bytes,
+ )
+ for key_id in updated_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ defer.returnValue(verify_keys)
+
+
+class PerspectivesKeyFetcher(BaseV2KeyFetcher):
+ """KeyFetcher impl which fetches keys from the "perspectives" servers"""
+
+ def __init__(self, hs):
+ super(PerspectivesKeyFetcher, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.client = hs.get_http_client()
+ self.perspective_servers = self.config.perspectives
+
@defer.inlineCallbacks
- def get_keys_from_perspectives(self, server_name_and_key_ids):
+ def get_keys(self, server_name_and_key_ids):
+ """see KeyFetcher.get_keys"""
+
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
@@ -409,28 +535,6 @@ class Keyring(object):
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
- def get_keys_from_server(self, server_name_and_key_ids):
- results = yield logcontext.make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self.get_server_verify_key_v2_direct, server_name, key_ids
- )
- for server_name, key_ids in server_name_and_key_ids
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
- )
-
- merged = {}
- for result in results:
- merged.update(result)
-
- defer.returnValue(
- {server_name: keys for server_name, keys in merged.items() if keys}
- )
-
- @defer.inlineCallbacks
def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys
):
@@ -520,6 +624,38 @@ class Keyring(object):
defer.returnValue(keys)
+
+class ServerKeyFetcher(BaseV2KeyFetcher):
+ """KeyFetcher impl which fetches keys from the origin servers"""
+
+ def __init__(self, hs):
+ super(ServerKeyFetcher, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.client = hs.get_http_client()
+
+ @defer.inlineCallbacks
+ def get_keys(self, server_name_and_key_ids):
+ """see KeyFetcher.get_keys"""
+ results = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self.get_server_verify_key_v2_direct, server_name, key_ids
+ )
+ for server_name, key_ids in server_name_and_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ merged = {}
+ for result in results:
+ merged.update(result)
+
+ defer.returnValue(
+ {server_name: keys for server_name, keys in merged.items() if keys}
+ )
+
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} # type: dict[str, FetchKeyResult]
@@ -568,107 +704,6 @@ class Keyring(object):
defer.returnValue({server_name: keys})
- @defer.inlineCallbacks
- def process_v2_response(
- self, from_server, response_json, time_added_ms, requested_ids=[]
- ):
- """Parse a 'Server Keys' structure from the result of a /key request
-
- This is used to parse either the entirety of the response from
- GET /_matrix/key/v2/server, or a single entry from the list returned by
- POST /_matrix/key/v2/query.
-
- Checks that each signature in the response that claims to come from the origin
- server is valid. (Does not check that there actually is such a signature, for
- some reason.)
-
- Stores the json in server_keys_json so that it can be used for future responses
- to /_matrix/key/v2/query.
-
- Args:
- from_server (str): the name of the server producing this result: either
- the origin server for a /_matrix/key/v2/server request, or the notary
- for a /_matrix/key/v2/query.
-
- response_json (dict): the json-decoded Server Keys response object
-
- time_added_ms (int): the timestamp to record in server_keys_json
-
- requested_ids (iterable[str]): a list of the key IDs that were requested.
- We will store the json for these key ids as well as any that are
- actually in the response
-
- Returns:
- Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
- """
- ts_valid_until_ms = response_json[u"valid_until_ts"]
-
- # start by extracting the keys from the response, since they may be required
- # to validate the signature on the response.
- verify_keys = {}
- for key_id, key_data in response_json["verify_keys"].items():
- if is_signing_algorithm_supported(key_id):
- key_base64 = key_data["key"]
- key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_keys[key_id] = FetchKeyResult(
- verify_key=verify_key, valid_until_ts=ts_valid_until_ms
- )
-
- # TODO: improve this signature checking
- server_name = response_json["server_name"]
- for key_id in response_json["signatures"].get(server_name, {}):
- if key_id not in verify_keys:
- raise KeyLookupError(
- "Key response must include verification keys for all signatures"
- )
-
- verify_signed_json(
- response_json, server_name, verify_keys[key_id].verify_key
- )
-
- for key_id, key_data in response_json["old_verify_keys"].items():
- if is_signing_algorithm_supported(key_id):
- key_base64 = key_data["key"]
- key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_keys[key_id] = FetchKeyResult(
- verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
- )
-
- # re-sign the json with our own key, so that it is ready if we are asked to
- # give it out as a notary server
- signed_key_json = sign_json(
- response_json, self.config.server_name, self.config.signing_key[0]
- )
-
- signed_key_json_bytes = encode_canonical_json(signed_key_json)
-
- # for reasons I don't quite understand, we store this json for the key ids we
- # requested, as well as those we got.
- updated_key_ids = set(requested_ids)
- updated_key_ids.update(verify_keys)
-
- yield logcontext.make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self.store.store_server_keys_json,
- server_name=server_name,
- key_id=key_id,
- from_server=from_server,
- ts_now_ms=time_added_ms,
- ts_expires_ms=ts_valid_until_ms,
- key_json_bytes=signed_key_json_bytes,
- )
- for key_id in updated_key_ids
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
- )
-
- defer.returnValue(verify_keys)
-
@defer.inlineCallbacks
def _handle_key_deferred(verify_request):
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index eb8782aa6e..21c3c807b9 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -20,7 +20,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import KeyLookupError, ServerKeyFetcher
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
from synapse.http.servlet import parse_integer, parse_json_object_from_request
@@ -89,7 +89,7 @@ class RemoteKey(Resource):
isLeaf = True
def __init__(self, hs):
- self.keyring = hs.get_keyring()
+ self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -217,7 +217,7 @@ class RemoteKey(Resource):
if cache_misses and query_remote_on_cache_miss:
for server_name, key_ids in cache_misses.items():
try:
- yield self.keyring.get_server_verify_key_v2_direct(
+ yield self.fetcher.get_server_verify_key_v2_direct(
server_name, key_ids
)
except KeyLookupError as e:
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 83de32b05d..de61bad15d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -24,7 +24,11 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import (
+ KeyLookupError,
+ PerspectivesKeyFetcher,
+ ServerKeyFetcher,
+)
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
@@ -218,12 +222,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(d.called)
self.get_success(d)
+
+class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.http_client = Mock()
+ hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+ return hs
+
def test_get_keys_from_server(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
SERVER_NAME = "server2"
- kr = keyring.Keyring(self.hs)
+ fetcher = ServerKeyFetcher(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
@@ -250,7 +261,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
+ keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
@@ -278,15 +289,26 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# change the server name: it should cause a rejection
response["server_name"] = "OTHER_SERVER"
self.get_failure(
- kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
+ fetcher.get_keys(server_name_and_key_ids), KeyLookupError
)
+
+class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.mock_perspective_server = MockPerspectiveServer()
+ self.http_client = Mock()
+ hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+ keys = self.mock_perspective_server.get_verify_keys()
+ hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
+ return hs
+
def test_get_keys_from_perspectives(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
SERVER_NAME = "server2"
- kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
@@ -320,7 +342,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
+ keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
|