diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 3c79d4afe7..3933ad4347 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -24,7 +24,12 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import (
+ KeyLookupError,
+ PerspectivesKeyFetcher,
+ ServerKeyFetcher,
+)
+from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
@@ -50,11 +55,11 @@ class MockPerspectiveServer(object):
key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
},
}
- return self.get_signed_response(res)
+ self.sign_response(res)
+ return res
- def get_signed_response(self, res):
+ def sign_response(self, res):
signedjson.sign.sign_json(res, self.server_name, self.key)
- return res
class KeyringTestCase(unittest.HomeserverTestCase):
@@ -80,7 +85,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# we run the lookup in a logcontext so that the patched inlineCallbacks can check
# it is doing the right thing with logcontexts.
wait_1_deferred = run_in_context(
- kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred}
+ kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
)
# there were no previous lookups, so the deferred should be ready
@@ -89,7 +94,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# set off another wait. It should block because the first lookup
# hasn't yet completed.
wait_2_deferred = run_in_context(
- kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred}
+ kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
)
self.assertFalse(wait_2_deferred.called)
@@ -192,8 +197,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_key(
- "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
+ key1_id = "%s:%s" % (key1.alg, key1.version)
+
+ r = self.hs.datastore.store_server_verify_keys(
+ "server9",
+ time.time() * 1000,
+ [
+ (
+ "server9",
+ key1_id,
+ FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
+ ),
+ ],
)
self.get_success(r)
json1 = {}
@@ -207,16 +222,23 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(d.called)
self.get_success(d)
+
+class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.http_client = Mock()
+ hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+ return hs
+
def test_get_keys_from_server(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
SERVER_NAME = "server2"
- kr = keyring.Keyring(self.hs)
+ fetcher = ServerKeyFetcher(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
- VALID_UNTIL_TS = 1000
+ VALID_UNTIL_TS = 200 * 1000
# valid response
response = {
@@ -239,11 +261,12 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
+ keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
k = keys[SERVER_NAME][testverifykey_id]
- self.assertEqual(k, testverifykey)
- self.assertEqual(k.alg, "ed25519")
- self.assertEqual(k.version, "ver1")
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -266,15 +289,26 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# change the server name: it should cause a rejection
response["server_name"] = "OTHER_SERVER"
self.get_failure(
- kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
+ fetcher.get_keys(server_name_and_key_ids), KeyLookupError
)
+
+class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.mock_perspective_server = MockPerspectiveServer()
+ self.http_client = Mock()
+ hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+ keys = self.mock_perspective_server.get_verify_keys()
+ hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
+ return hs
+
def test_get_keys_from_perspectives(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
SERVER_NAME = "server2"
- kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
@@ -292,9 +326,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
},
}
- persp_resp = {
- "server_keys": [self.mock_perspective_server.get_signed_response(response)]
- }
+ # the response must be signed by both the origin server and the perspectives
+ # server.
+ signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+ self.mock_perspective_server.sign_response(response)
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -303,17 +338,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# check that the request is for the expected key
q = data["server_keys"]
self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
- return persp_resp
+ return {"server_keys": [response]}
self.http_client.post_json.side_effect = post_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
+ keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
- self.assertEqual(k, testverifykey)
- self.assertEqual(k.alg, "ed25519")
- self.assertEqual(k.version, "ver1")
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -330,13 +366,81 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertEqual(
bytes(res["key_json"]),
- canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]),
+ canonicaljson.encode_canonical_json(response),
)
+ def test_invalid_perspectives_responses(self):
+ """Check that invalid responses from the perspectives server are rejected"""
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ SERVER_NAME = "server2"
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 200 * 1000
+
+ def build_response():
+ # valid response
+ response = {
+ "server_name": SERVER_NAME,
+ "old_verify_keys": {},
+ "valid_until_ts": VALID_UNTIL_TS,
+ "verify_keys": {
+ testverifykey_id: {
+ "key": signedjson.key.encode_verify_key_base64(testverifykey)
+ }
+ },
+ }
+
+ # the response must be signed by both the origin server and the perspectives
+ # server.
+ signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+ self.mock_perspective_server.sign_response(response)
+ return response
+
+ def get_key_from_perspectives(response):
+ fetcher = PerspectivesKeyFetcher(self.hs)
+ server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
+
+ def post_json(destination, path, data, **kwargs):
+ self.assertEqual(destination, self.mock_perspective_server.server_name)
+ self.assertEqual(path, "/_matrix/key/v2/query")
+ return {"server_keys": [response]}
+
+ self.http_client.post_json.side_effect = post_json
+
+ return self.get_success(
+ fetcher.get_keys(server_name_and_key_ids)
+ )
+
+ # start with a valid response so we can check we are testing the right thing
+ response = build_response()
+ keys = get_key_from_perspectives(response)
+ k = keys[SERVER_NAME][testverifykey_id]
+ self.assertEqual(k.verify_key, testverifykey)
+
+ # remove the perspectives server's signature
+ response = build_response()
+ del response["signatures"][self.mock_perspective_server.server_name]
+ self.http_client.post_json.return_value = {"server_keys": [response]}
+ keys = get_key_from_perspectives(response)
+ self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
+
+ # remove the origin server's signature
+ response = build_response()
+ del response["signatures"][SERVER_NAME]
+ self.http_client.post_json.return_value = {"server_keys": [response]}
+ keys = get_key_from_perspectives(response)
+ self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
+
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
- with LoggingContext("testctx"):
+ with LoggingContext("testctx") as ctx:
+ # we set the "request" prop to make it easier to follow what's going on in the
+ # logs.
+ ctx.request = "testctx"
rv = yield f(*args, **kwargs)
defer.returnValue(rv)
|