diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 3933ad4347..5a355f00cc 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,15 +19,16 @@ from mock import Mock
import canonicaljson
import signedjson.key
import signedjson.sign
+from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
from synapse.crypto.keyring import (
- KeyLookupError,
PerspectivesKeyFetcher,
ServerKeyFetcher,
+ StoreKeyFetcher,
)
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext
@@ -43,7 +44,7 @@ class MockPerspectiveServer(object):
def get_verify_keys(self):
vk = signedjson.key.get_verify_key(self.key)
- return {"%s:%s" % (vk.alg, vk.version): vk}
+ return {"%s:%s" % (vk.alg, vk.version): encode_verify_key_base64(vk)}
def get_signed_key(self, server_name, verify_key):
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
@@ -51,9 +52,7 @@ class MockPerspectiveServer(object):
"server_name": server_name,
"old_verify_keys": {},
"valid_until_ts": time.time() * 1000 + 3600,
- "verify_keys": {
- key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
- },
+ "verify_keys": {key_id: {"key": encode_verify_key_base64(verify_key)}},
}
self.sign_response(res)
return res
@@ -66,10 +65,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock()
- hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
- keys = self.mock_perspective_server.get_verify_keys()
- hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
- return hs
+
+ config = self.default_config()
+ config["trusted_key_servers"] = [
+ {
+ "server_name": self.mock_perspective_server.server_name,
+ "verify_keys": self.mock_perspective_server.get_verify_keys(),
+ }
+ ]
+
+ return self.setup_test_homeserver(
+ handlers=None, http_client=self.http_client, config=config
+ )
def check_context(self, _, expected):
self.assertEquals(
@@ -137,7 +144,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
context_11.request = "11"
res_deferreds = kr.verify_json_objects_for_server(
- [("server10", json1), ("server11", {})]
+ [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
)
# the unsigned json should be rejected pretty quickly
@@ -174,7 +181,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server(
- [("server10", json1)]
+ [("server10", json1, 0, "test")]
)
res_deferreds_2[0].addBoth(self.check_context, None)
yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
@@ -197,31 +204,152 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- key1_id = "%s:%s" % (key1.alg, key1.version)
+ r = self.hs.datastore.store_server_verify_keys(
+ "server9",
+ time.time() * 1000,
+ [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
+ )
+ self.get_success(r)
+
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server9", key1)
+
+ # should fail immediately on an unsigned object
+ d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
+ self.failureResultOf(d, SynapseError)
+
+ # should suceed on a signed object
+ d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
+ # self.assertFalse(d.called)
+ self.get_success(d)
+ def test_verify_json_for_server_with_null_valid_until_ms(self):
+ """Tests that we correctly handle key requests for keys we've stored
+ with a null `ts_valid_until_ms`
+ """
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
+
+ kr = keyring.Keyring(
+ self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
+ )
+
+ key1 = signedjson.key.generate_signing_key(1)
r = self.hs.datastore.store_server_verify_keys(
"server9",
time.time() * 1000,
- [
- (
- "server9",
- key1_id,
- FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
- ),
- ],
+ [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
)
self.get_success(r)
+
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
# should fail immediately on an unsigned object
- d = _verify_json_for_server(kr, "server9", {})
+ d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
self.failureResultOf(d, SynapseError)
- d = _verify_json_for_server(kr, "server9", json1)
- self.assertFalse(d.called)
+ # should fail on a signed object with a non-zero minimum_valid_until_ms,
+ # as it tries to refetch the keys and fails.
+ d = _verify_json_for_server(
+ kr, "server9", json1, 500, "test signed non-zero min"
+ )
+ self.get_failure(d, SynapseError)
+
+ # We expect the keyring tried to refetch the key once.
+ mock_fetcher.get_keys.assert_called_once_with(
+ {"server9": {get_key_id(key1): 500}}
+ )
+
+ # should succeed on a signed object with a 0 minimum_valid_until_ms
+ d = _verify_json_for_server(
+ kr, "server9", json1, 0, "test signed with zero min"
+ )
self.get_success(d)
+ def test_verify_json_dedupes_key_requests(self):
+ """Two requests for the same key should be deduped."""
+ key1 = signedjson.key.generate_signing_key(1)
+
+ def get_keys(keys_to_fetch):
+ # there should only be one request object (with the max validity)
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+ }
+ }
+ )
+
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock(side_effect=get_keys)
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
+
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server1", key1)
+
+ # the first request should succeed; the second should fail because the key
+ # has expired
+ results = kr.verify_json_objects_for_server(
+ [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
+ )
+ self.assertEqual(len(results), 2)
+ self.get_success(results[0])
+ e = self.get_failure(results[1], SynapseError).value
+ self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+ self.assertEqual(e.code, 401)
+
+ # there should have been a single call to the fetcher
+ mock_fetcher.get_keys.assert_called_once()
+
+ def test_verify_json_falls_back_to_other_fetchers(self):
+ """If the first fetcher cannot provide a recent enough key, we fall back"""
+ key1 = signedjson.key.generate_signing_key(1)
+
+ def get_keys1(keys_to_fetch):
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
+ }
+ }
+ )
+
+ def get_keys2(keys_to_fetch):
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+ }
+ }
+ )
+
+ mock_fetcher1 = keyring.KeyFetcher()
+ mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
+ mock_fetcher2 = keyring.KeyFetcher()
+ mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
+
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server1", key1)
+
+ results = kr.verify_json_objects_for_server(
+ [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
+ )
+ self.assertEqual(len(results), 2)
+ self.get_success(results[0])
+ e = self.get_failure(results[1], SynapseError).value
+ self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+ self.assertEqual(e.code, 401)
+
+ # there should have been a single call to each fetcher
+ mock_fetcher1.get_keys.assert_called_once()
+ mock_fetcher2.get_keys.assert_called_once()
+
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
@@ -260,8 +388,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json
- server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
@@ -286,21 +414,29 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
- # change the server name: it should cause a rejection
+ # change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
- self.get_failure(
- fetcher.get_keys(server_name_and_key_ids), KeyLookupError
- )
+
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+ self.assertEqual(keys, {})
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock()
- hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
- keys = self.mock_perspective_server.get_verify_keys()
- hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
- return hs
+
+ config = self.default_config()
+ config["trusted_key_servers"] = [
+ {
+ "server_name": self.mock_perspective_server.server_name,
+ "verify_keys": self.mock_perspective_server.get_verify_keys(),
+ }
+ ]
+
+ return self.setup_test_homeserver(
+ handlers=None, http_client=self.http_client, config=config
+ )
def test_get_keys_from_perspectives(self):
# arbitrarily advance the clock a bit
@@ -342,8 +478,8 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json
- server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
@@ -365,8 +501,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
self.assertEqual(
- bytes(res["key_json"]),
- canonicaljson.encode_canonical_json(response),
+ bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
def test_invalid_perspectives_responses(self):
@@ -401,7 +536,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def get_key_from_perspectives(response):
fetcher = PerspectivesKeyFetcher(self.hs)
- server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -410,9 +545,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json
- return self.get_success(
- fetcher.get_keys(server_name_and_key_ids)
- )
+ return self.get_success(fetcher.get_keys(keys_to_fetch))
# start with a valid response so we can check we are testing the right thing
response = build_response()
@@ -435,6 +568,11 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
+def get_key_id(key):
+ """Get the matrix ID tag for a given SigningKey or VerifyKey"""
+ return "%s:%s" % (key.alg, key.version)
+
+
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx") as ctx:
@@ -445,14 +583,14 @@ def run_in_context(f, *args, **kwargs):
defer.returnValue(rv)
-def _verify_json_for_server(keyring, server_name, json_object):
+def _verify_json_for_server(kr, *args):
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
with the patched defer.inlineCallbacks.
"""
@defer.inlineCallbacks
def v():
- rv1 = yield keyring.verify_json_for_server(server_name, json_object)
+ rv1 = yield kr.verify_json_for_server(*args)
defer.returnValue(rv1)
return run_in_context(v)
|