diff options
Diffstat (limited to 'tests/crypto/test_keyring.py')
-rw-r--r-- | tests/crypto/test_keyring.py | 222 |
1 files changed, 180 insertions, 42 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 3933ad4347..5a355f00cc 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,15 +19,16 @@ from mock import Mock import canonicaljson import signedjson.key import signedjson.sign +from signedjson.key import encode_verify_key_base64, get_verify_key from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.crypto import keyring from synapse.crypto.keyring import ( - KeyLookupError, PerspectivesKeyFetcher, ServerKeyFetcher, + StoreKeyFetcher, ) from synapse.storage.keys import FetchKeyResult from synapse.util import logcontext @@ -43,7 +44,7 @@ class MockPerspectiveServer(object): def get_verify_keys(self): vk = signedjson.key.get_verify_key(self.key) - return {"%s:%s" % (vk.alg, vk.version): vk} + return {"%s:%s" % (vk.alg, vk.version): encode_verify_key_base64(vk)} def get_signed_key(self, server_name, verify_key): key_id = "%s:%s" % (verify_key.alg, verify_key.version) @@ -51,9 +52,7 @@ class MockPerspectiveServer(object): "server_name": server_name, "old_verify_keys": {}, "valid_until_ts": time.time() * 1000 + 3600, - "verify_keys": { - key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)} - }, + "verify_keys": {key_id: {"key": encode_verify_key_base64(verify_key)}}, } self.sign_response(res) return res @@ -66,10 +65,18 @@ class KeyringTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.mock_perspective_server = MockPerspectiveServer() self.http_client = Mock() - hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) - keys = self.mock_perspective_server.get_verify_keys() - hs.config.perspectives = {self.mock_perspective_server.server_name: keys} - return hs + + config = self.default_config() + config["trusted_key_servers"] = [ + { + "server_name": self.mock_perspective_server.server_name, + "verify_keys": self.mock_perspective_server.get_verify_keys(), + } + ] + + return self.setup_test_homeserver( + handlers=None, http_client=self.http_client, config=config + ) def check_context(self, _, expected): self.assertEquals( @@ -137,7 +144,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): context_11.request = "11" res_deferreds = kr.verify_json_objects_for_server( - [("server10", json1), ("server11", {})] + [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] ) # the unsigned json should be rejected pretty quickly @@ -174,7 +181,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.http_client.post_json.return_value = defer.Deferred() res_deferreds_2 = kr.verify_json_objects_for_server( - [("server10", json1)] + [("server10", json1, 0, "test")] ) res_deferreds_2[0].addBoth(self.check_context, None) yield logcontext.make_deferred_yieldable(res_deferreds_2[0]) @@ -197,31 +204,152 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - key1_id = "%s:%s" % (key1.alg, key1.version) + r = self.hs.datastore.store_server_verify_keys( + "server9", + time.time() * 1000, + [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], + ) + self.get_success(r) + + json1 = {} + signedjson.sign.sign_json(json1, "server9", key1) + + # should fail immediately on an unsigned object + d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") + self.failureResultOf(d, SynapseError) + + # should suceed on a signed object + d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") + # self.assertFalse(d.called) + self.get_success(d) + def test_verify_json_for_server_with_null_valid_until_ms(self): + """Tests that we correctly handle key requests for keys we've stored + with a null `ts_valid_until_ms` + """ + mock_fetcher = keyring.KeyFetcher() + mock_fetcher.get_keys = Mock(return_value=defer.succeed({})) + + kr = keyring.Keyring( + self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) + ) + + key1 = signedjson.key.generate_signing_key(1) r = self.hs.datastore.store_server_verify_keys( "server9", time.time() * 1000, - [ - ( - "server9", - key1_id, - FetchKeyResult(signedjson.key.get_verify_key(key1), 1000), - ), - ], + [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], ) self.get_success(r) + json1 = {} signedjson.sign.sign_json(json1, "server9", key1) # should fail immediately on an unsigned object - d = _verify_json_for_server(kr, "server9", {}) + d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") self.failureResultOf(d, SynapseError) - d = _verify_json_for_server(kr, "server9", json1) - self.assertFalse(d.called) + # should fail on a signed object with a non-zero minimum_valid_until_ms, + # as it tries to refetch the keys and fails. + d = _verify_json_for_server( + kr, "server9", json1, 500, "test signed non-zero min" + ) + self.get_failure(d, SynapseError) + + # We expect the keyring tried to refetch the key once. + mock_fetcher.get_keys.assert_called_once_with( + {"server9": {get_key_id(key1): 500}} + ) + + # should succeed on a signed object with a 0 minimum_valid_until_ms + d = _verify_json_for_server( + kr, "server9", json1, 0, "test signed with zero min" + ) self.get_success(d) + def test_verify_json_dedupes_key_requests(self): + """Two requests for the same key should be deduped.""" + key1 = signedjson.key.generate_signing_key(1) + + def get_keys(keys_to_fetch): + # there should only be one request object (with the max validity) + self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) + + return defer.succeed( + { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) + } + } + ) + + mock_fetcher = keyring.KeyFetcher() + mock_fetcher.get_keys = Mock(side_effect=get_keys) + kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) + + json1 = {} + signedjson.sign.sign_json(json1, "server1", key1) + + # the first request should succeed; the second should fail because the key + # has expired + results = kr.verify_json_objects_for_server( + [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")] + ) + self.assertEqual(len(results), 2) + self.get_success(results[0]) + e = self.get_failure(results[1], SynapseError).value + self.assertEqual(e.errcode, "M_UNAUTHORIZED") + self.assertEqual(e.code, 401) + + # there should have been a single call to the fetcher + mock_fetcher.get_keys.assert_called_once() + + def test_verify_json_falls_back_to_other_fetchers(self): + """If the first fetcher cannot provide a recent enough key, we fall back""" + key1 = signedjson.key.generate_signing_key(1) + + def get_keys1(keys_to_fetch): + self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) + return defer.succeed( + { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800) + } + } + ) + + def get_keys2(keys_to_fetch): + self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) + return defer.succeed( + { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) + } + } + ) + + mock_fetcher1 = keyring.KeyFetcher() + mock_fetcher1.get_keys = Mock(side_effect=get_keys1) + mock_fetcher2 = keyring.KeyFetcher() + mock_fetcher2.get_keys = Mock(side_effect=get_keys2) + kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2)) + + json1 = {} + signedjson.sign.sign_json(json1, "server1", key1) + + results = kr.verify_json_objects_for_server( + [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")] + ) + self.assertEqual(len(results), 2) + self.get_success(results[0]) + e = self.get_failure(results[1], SynapseError).value + self.assertEqual(e.errcode, "M_UNAUTHORIZED") + self.assertEqual(e.code, 401) + + # there should have been a single call to each fetcher + mock_fetcher1.get_keys.assert_called_once() + mock_fetcher2.get_keys.assert_called_once() + class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): @@ -260,8 +388,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): self.http_client.get_json.side_effect = get_json - server_name_and_key_ids = [(SERVER_NAME, ("key1",))] - keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) k = keys[SERVER_NAME][testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.verify_key, testverifykey) @@ -286,21 +414,29 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) ) - # change the server name: it should cause a rejection + # change the server name: the result should be ignored response["server_name"] = "OTHER_SERVER" - self.get_failure( - fetcher.get_keys(server_name_and_key_ids), KeyLookupError - ) + + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) + self.assertEqual(keys, {}) class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.mock_perspective_server = MockPerspectiveServer() self.http_client = Mock() - hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) - keys = self.mock_perspective_server.get_verify_keys() - hs.config.perspectives = {self.mock_perspective_server.server_name: keys} - return hs + + config = self.default_config() + config["trusted_key_servers"] = [ + { + "server_name": self.mock_perspective_server.server_name, + "verify_keys": self.mock_perspective_server.get_verify_keys(), + } + ] + + return self.setup_test_homeserver( + handlers=None, http_client=self.http_client, config=config + ) def test_get_keys_from_perspectives(self): # arbitrarily advance the clock a bit @@ -342,8 +478,8 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.http_client.post_json.side_effect = post_json - server_name_and_key_ids = [(SERVER_NAME, ("key1",))] - keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) self.assertIn(SERVER_NAME, keys) k = keys[SERVER_NAME][testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) @@ -365,8 +501,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) self.assertEqual( - bytes(res["key_json"]), - canonicaljson.encode_canonical_json(response), + bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) ) def test_invalid_perspectives_responses(self): @@ -401,7 +536,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): def get_key_from_perspectives(response): fetcher = PerspectivesKeyFetcher(self.hs) - server_name_and_key_ids = [(SERVER_NAME, ("key1",))] + keys_to_fetch = {SERVER_NAME: {"key1": 0}} def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) @@ -410,9 +545,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.http_client.post_json.side_effect = post_json - return self.get_success( - fetcher.get_keys(server_name_and_key_ids) - ) + return self.get_success(fetcher.get_keys(keys_to_fetch)) # start with a valid response so we can check we are testing the right thing response = build_response() @@ -435,6 +568,11 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") +def get_key_id(key): + """Get the matrix ID tag for a given SigningKey or VerifyKey""" + return "%s:%s" % (key.alg, key.version) + + @defer.inlineCallbacks def run_in_context(f, *args, **kwargs): with LoggingContext("testctx") as ctx: @@ -445,14 +583,14 @@ def run_in_context(f, *args, **kwargs): defer.returnValue(rv) -def _verify_json_for_server(keyring, server_name, json_object): +def _verify_json_for_server(kr, *args): """thin wrapper around verify_json_for_server which makes sure it is wrapped with the patched defer.inlineCallbacks. """ @defer.inlineCallbacks def v(): - rv1 = yield keyring.verify_json_for_server(server_name, json_object) + rv1 = yield kr.verify_json_for_server(*args) defer.returnValue(rv1) return run_in_context(v) |