diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index eaf41b983c..a64ba0752a 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -394,8 +394,7 @@ class BaseV2KeyFetcher(object):
POST /_matrix/key/v2/query.
Checks that each signature in the response that claims to come from the origin
- server is valid. (Does not check that there actually is such a signature, for
- some reason.)
+ server is valid, and that there is at least one such signature.
Stores the json in server_keys_json so that it can be used for future responses
to /_matrix/key/v2/query.
@@ -430,16 +429,25 @@ class BaseV2KeyFetcher(object):
verify_key=verify_key, valid_until_ts=ts_valid_until_ms
)
- # TODO: improve this signature checking
server_name = response_json["server_name"]
+ verified = False
for key_id in response_json["signatures"].get(server_name, {}):
- if key_id not in verify_keys:
+ # each of the keys used for the signature must be present in the response
+ # json.
+ key = verify_keys.get(key_id)
+ if not key:
raise KeyLookupError(
- "Key response must include verification keys for all signatures"
+ "Key response is signed by key id %s:%s but that key is not "
+ "present in the response" % (server_name, key_id)
)
- verify_signed_json(
- response_json, server_name, verify_keys[key_id].verify_key
+ verify_signed_json(response_json, server_name, key.verify_key)
+ verified = True
+
+ if not verified:
+ raise KeyLookupError(
+ "Key response for %s is not signed by the origin server"
+ % (server_name,)
)
for key_id, key_data in response_json["old_verify_keys"].items():
@@ -677,12 +685,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise_from(KeyLookupError("Remote server returned an error"), e)
- if (
- u"signatures" not in response
- or server_name not in response[u"signatures"]
- ):
- raise KeyLookupError("Key response not signed by remote server")
-
if response["server_name"] != server_name:
raise KeyLookupError(
"Expected a response for server %r not %r"
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index de61bad15d..c4c9d29499 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -55,11 +55,11 @@ class MockPerspectiveServer(object):
key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
},
}
- return self.get_signed_response(res)
+ self.sign_response(res)
+ return res
- def get_signed_response(self, res):
+ def sign_response(self, res):
signedjson.sign.sign_json(res, self.server_name, self.key)
- return res
class KeyringTestCase(unittest.HomeserverTestCase):
@@ -238,7 +238,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
- VALID_UNTIL_TS = 1000
+ VALID_UNTIL_TS = 200 * 1000
# valid response
response = {
@@ -326,9 +326,10 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
},
}
- persp_resp = {
- "server_keys": [self.mock_perspective_server.get_signed_response(response)]
- }
+ # the response must be signed by both the origin server and the perspectives
+ # server.
+ signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+ self.mock_perspective_server.sign_response(response)
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -337,7 +338,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# check that the request is for the expected key
q = data["server_keys"]
self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
- return persp_resp
+ return {"server_keys": [response]}
self.http_client.post_json.side_effect = post_json
@@ -365,9 +366,74 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(
bytes(res["key_json"]),
- canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]),
+ canonicaljson.encode_canonical_json(response),
)
+ def test_invalid_perspectives_responses(self):
+ """Check that invalid responses from the perspectives server are rejected"""
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ SERVER_NAME = "server2"
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 200 * 1000
+
+ def build_response():
+ # valid response
+ response = {
+ "server_name": SERVER_NAME,
+ "old_verify_keys": {},
+ "valid_until_ts": VALID_UNTIL_TS,
+ "verify_keys": {
+ testverifykey_id: {
+ "key": signedjson.key.encode_verify_key_base64(testverifykey)
+ }
+ },
+ }
+
+ # the response must be signed by both the origin server and the perspectives
+ # server.
+ signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+ self.mock_perspective_server.sign_response(response)
+ return response
+
+ def get_key_from_perspectives(response):
+ fetcher = PerspectivesKeyFetcher(self.hs)
+ server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
+
+ def post_json(destination, path, data, **kwargs):
+ self.assertEqual(destination, self.mock_perspective_server.server_name)
+ self.assertEqual(path, "/_matrix/key/v2/query")
+ return {"server_keys": [response]}
+
+ self.http_client.post_json.side_effect = post_json
+
+ return self.get_success(
+ fetcher.get_keys(server_name_and_key_ids)
+ )
+
+ # start with a valid response so we can check we are testing the right thing
+ response = build_response()
+ keys = get_key_from_perspectives(response)
+ k = keys[SERVER_NAME][testverifykey_id]
+ self.assertEqual(k.verify_key, testverifykey)
+
+ # remove the perspectives server's signature
+ response = build_response()
+ del response["signatures"][self.mock_perspective_server.server_name]
+ self.http_client.post_json.return_value = {"server_keys": [response]}
+ keys = get_key_from_perspectives(response)
+ self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
+
+ # remove the origin server's signature
+ response = build_response()
+ del response["signatures"][SERVER_NAME]
+ self.http_client.post_json.return_value = {"server_keys": [response]}
+ keys = get_key_from_perspectives(response)
+ self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
+
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
|