diff options
-rw-r--r-- | changelog.d/5001.misc | 1 | ||||
-rw-r--r-- | synapse/crypto/keyring.py | 79 | ||||
-rw-r--r-- | synapse/storage/keys.py | 4 | ||||
-rw-r--r-- | tests/crypto/test_keyring.py | 130 |
4 files changed, 181 insertions, 33 deletions
diff --git a/changelog.d/5001.misc b/changelog.d/5001.misc new file mode 100644 index 0000000000..bf590a016d --- /dev/null +++ b/changelog.d/5001.misc @@ -0,0 +1 @@ +Clean up some code in the server-key Keyring. \ No newline at end of file diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 0207cd989a..54af60d711 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -20,6 +20,7 @@ from collections import namedtuple from six import raise_from from six.moves import urllib +import nacl.signing from signedjson.key import ( decode_verify_key_bytes, encode_verify_key_base64, @@ -494,11 +495,11 @@ class Keyring(object): ) processed_response = yield self.process_v2_response( - perspective_name, response, only_from_server=False + perspective_name, response ) + server_name = response["server_name"] - for server_name, response_keys in processed_response.items(): - keys.setdefault(server_name, {}).update(response_keys) + keys.setdefault(server_name, {}).update(processed_response) yield logcontext.make_deferred_yieldable(defer.gatherResults( [ @@ -517,7 +518,7 @@ class Keyring(object): @defer.inlineCallbacks def get_server_verify_key_v2_direct(self, server_name, key_ids): - keys = {} + keys = {} # type: dict[str, nacl.signing.VerifyKey] for requested_key_id in key_ids: if requested_key_id in keys: @@ -542,6 +543,11 @@ class Keyring(object): or server_name not in response[u"signatures"]): raise KeyLookupError("Key response not signed by remote server") + if response["server_name"] != server_name: + raise KeyLookupError("Expected a response for server %r not %r" % ( + server_name, response["server_name"] + )) + response_keys = yield self.process_v2_response( from_server=server_name, requested_ids=[requested_key_id], @@ -550,24 +556,45 @@ class Keyring(object): keys.update(response_keys) - yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - run_in_background( - self.store_keys, - server_name=key_server_name, - from_server=server_name, - verify_keys=verify_keys, - ) - for key_server_name, verify_keys in keys.items() - ], - consumeErrors=True - ).addErrback(unwrapFirstError)) - - defer.returnValue(keys) + yield self.store_keys( + server_name=server_name, + from_server=server_name, + verify_keys=keys, + ) + defer.returnValue({server_name: keys}) @defer.inlineCallbacks - def process_v2_response(self, from_server, response_json, - requested_ids=[], only_from_server=True): + def process_v2_response( + self, from_server, response_json, requested_ids=[], + ): + """Parse a 'Server Keys' structure from the result of a /key request + + This is used to parse either the entirety of the response from + GET /_matrix/key/v2/server, or a single entry from the list returned by + POST /_matrix/key/v2/query. + + Checks that each signature in the response that claims to come from the origin + server is valid. (Does not check that there actually is such a signature, for + some reason.) + + Stores the json in server_keys_json so that it can be used for future responses + to /_matrix/key/v2/query. + + Args: + from_server (str): the name of the server producing this result: either + the origin server for a /_matrix/key/v2/server request, or the notary + for a /_matrix/key/v2/query. + + response_json (dict): the json-decoded Server Keys response object + + requested_ids (iterable[str]): a list of the key IDs that were requested. + We will store the json for these key ids as well as any that are + actually in the response + + Returns: + Deferred[dict[str, nacl.signing.VerifyKey]]: + map from key_id to key object + """ time_now_ms = self.clock.time_msec() response_keys = {} verify_keys = {} @@ -589,15 +616,7 @@ class Keyring(object): verify_key.time_added = time_now_ms old_verify_keys[key_id] = verify_key - results = {} server_name = response_json["server_name"] - if only_from_server: - if server_name != from_server: - raise KeyLookupError( - "Expected a response for server %r not %r" % ( - from_server, server_name - ) - ) for key_id in response_json["signatures"].get(server_name, {}): if key_id not in response_json["verify_keys"]: raise KeyLookupError( @@ -643,9 +662,7 @@ class Keyring(object): consumeErrors=True, ).addErrback(unwrapFirstError)) - results[server_name] = response_keys - - defer.returnValue(results) + defer.returnValue(response_keys) def store_keys(self, server_name, from_server, verify_keys): """Store a collection of verify keys for a given server diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 030cd1e5a3..f24ab3eedd 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -188,8 +188,8 @@ class KeyStore(SQLBaseStore): Args: server_keys (list): List of (server_name, key_id, source) triplets. Returns: - Dict mapping (server_name, key_id, source) triplets to dicts with - "ts_valid_until_ms" and "key_json" keys. + Deferred[dict[Tuple[str, str, str|None], list[dict]]]: + Dict mapping (server_name, key_id, source) triplets to lists of dicts """ def _get_server_keys_json_txn(txn): diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index c30a1a69e7..b224fdb23a 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -16,6 +16,7 @@ import time from mock import Mock +import canonicaljson import signedjson.key import signedjson.sign @@ -23,6 +24,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.crypto import keyring +from synapse.crypto.keyring import KeyLookupError from synapse.util import logcontext from synapse.util.logcontext import LoggingContext @@ -48,6 +50,9 @@ class MockPerspectiveServer(object): key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)} }, } + return self.get_signed_response(res) + + def get_signed_response(self, res): signedjson.sign.sign_json(res, self.server_name, self.key) return res @@ -202,6 +207,131 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.assertFalse(d.called) self.get_success(d) + def test_get_keys_from_server(self): + # arbitrarily advance the clock a bit + self.reactor.advance(100) + + SERVER_NAME = "server2" + kr = keyring.Keyring(self.hs) + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 1000 + + # valid response + response = { + "server_name": SERVER_NAME, + "old_verify_keys": {}, + "valid_until_ts": VALID_UNTIL_TS, + "verify_keys": { + testverifykey_id: { + "key": signedjson.key.encode_verify_key_base64(testverifykey) + } + }, + } + signedjson.sign.sign_json(response, SERVER_NAME, testkey) + + def get_json(destination, path, **kwargs): + self.assertEqual(destination, SERVER_NAME) + self.assertEqual(path, "/_matrix/key/v2/server/key1") + return response + + self.http_client.get_json.side_effect = get_json + + server_name_and_key_ids = [(SERVER_NAME, ("key1",))] + keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids)) + k = keys[SERVER_NAME][testverifykey_id] + self.assertEqual(k, testverifykey) + self.assertEqual(k.alg, "ed25519") + self.assertEqual(k.version, "ver1") + + # check that the perspectives store is correctly updated + lookup_triplet = (SERVER_NAME, testverifykey_id, None) + key_json = self.get_success( + self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + ) + res = key_json[lookup_triplet] + self.assertEqual(len(res), 1) + res = res[0] + self.assertEqual(res["key_id"], testverifykey_id) + self.assertEqual(res["from_server"], SERVER_NAME) + self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) + self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + + # we expect it to be encoded as canonical json *before* it hits the db + self.assertEqual( + bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) + ) + + # change the server name: it should cause a rejection + response["server_name"] = "OTHER_SERVER" + self.get_failure( + kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError + ) + + def test_get_keys_from_perspectives(self): + # arbitrarily advance the clock a bit + self.reactor.advance(100) + + SERVER_NAME = "server2" + kr = keyring.Keyring(self.hs) + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 200 * 1000 + + # valid response + response = { + "server_name": SERVER_NAME, + "old_verify_keys": {}, + "valid_until_ts": VALID_UNTIL_TS, + "verify_keys": { + testverifykey_id: { + "key": signedjson.key.encode_verify_key_base64(testverifykey) + } + }, + } + + persp_resp = { + "server_keys": [self.mock_perspective_server.get_signed_response(response)] + } + + def post_json(destination, path, data, **kwargs): + self.assertEqual(destination, self.mock_perspective_server.server_name) + self.assertEqual(path, "/_matrix/key/v2/query") + + # check that the request is for the expected key + q = data["server_keys"] + self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"]) + return persp_resp + + self.http_client.post_json.side_effect = post_json + + server_name_and_key_ids = [(SERVER_NAME, ("key1",))] + keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids)) + self.assertIn(SERVER_NAME, keys) + k = keys[SERVER_NAME][testverifykey_id] + self.assertEqual(k, testverifykey) + self.assertEqual(k.alg, "ed25519") + self.assertEqual(k.version, "ver1") + + # check that the perspectives store is correctly updated + lookup_triplet = (SERVER_NAME, testverifykey_id, None) + key_json = self.get_success( + self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + ) + res = key_json[lookup_triplet] + self.assertEqual(len(res), 1) + res = res[0] + self.assertEqual(res["key_id"], testverifykey_id) + self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) + self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + + self.assertEqual( + bytes(res["key_json"]), + canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]), + ) + @defer.inlineCallbacks def run_in_context(f, *args, **kwargs): |