diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 5a355f00cc..34d5895f18 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,6 +19,7 @@ from mock import Mock
import canonicaljson
import signedjson.key
import signedjson.sign
+from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
@@ -30,9 +31,12 @@ from synapse.crypto.keyring import (
ServerKeyFetcher,
StoreKeyFetcher,
)
+from synapse.logging.context import (
+ LoggingContext,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+)
from synapse.storage.keys import FetchKeyResult
-from synapse.util import logcontext
-from synapse.util.logcontext import LoggingContext
from tests import unittest
@@ -83,35 +87,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
getattr(LoggingContext.current_context(), "request", None), expected
)
- def test_wait_for_previous_lookups(self):
- kr = keyring.Keyring(self.hs)
-
- lookup_1_deferred = defer.Deferred()
- lookup_2_deferred = defer.Deferred()
-
- # we run the lookup in a logcontext so that the patched inlineCallbacks can check
- # it is doing the right thing with logcontexts.
- wait_1_deferred = run_in_context(
- kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
- )
-
- # there were no previous lookups, so the deferred should be ready
- self.successResultOf(wait_1_deferred)
-
- # set off another wait. It should block because the first lookup
- # hasn't yet completed.
- wait_2_deferred = run_in_context(
- kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
- )
-
- self.assertFalse(wait_2_deferred.called)
-
- # let the first lookup complete (in the sentinel context)
- lookup_1_deferred.callback(None)
-
- # now the second wait should complete.
- self.successResultOf(wait_2_deferred)
-
def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1)
@@ -131,9 +106,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def get_perspectives(**kwargs):
self.assertEquals(LoggingContext.current_context().request, "11")
- with logcontext.PreserveLoggingContext():
+ with PreserveLoggingContext():
yield persp_deferred
- defer.returnValue(persp_resp)
+ return persp_resp
self.http_client.post_json.side_effect = get_perspectives
@@ -158,7 +133,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
- yield logcontext.make_deferred_yieldable(res_deferreds[0])
+ yield make_deferred_yieldable(res_deferreds[0])
# let verify_json_objects_for_server finish its work before we kill the
# logcontext
@@ -184,7 +159,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
[("server10", json1, 0, "test")]
)
res_deferreds_2[0].addBoth(self.check_context, None)
- yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
+ yield make_deferred_yieldable(res_deferreds_2[0])
# let verify_json_objects_for_server finish its work before we kill the
# logcontext
@@ -204,7 +179,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_keys(
+ r = self.hs.get_datastore().store_server_verify_keys(
"server9",
time.time() * 1000,
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
@@ -235,7 +210,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_keys(
+ r = self.hs.get_datastore().store_server_verify_keys(
"server9",
time.time() * 1000,
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
@@ -438,34 +413,37 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
handlers=None, http_client=self.http_client, config=config
)
- def test_get_keys_from_perspectives(self):
- # arbitrarily advance the clock a bit
- self.reactor.advance(100)
-
- fetcher = PerspectivesKeyFetcher(self.hs)
-
- SERVER_NAME = "server2"
- testkey = signedjson.key.generate_signing_key("ver1")
- testverifykey = signedjson.key.get_verify_key(testkey)
- testverifykey_id = "ed25519:ver1"
- VALID_UNTIL_TS = 200 * 1000
+ def build_perspectives_response(
+ self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
+ ) -> dict:
+ """
+ Build a valid perspectives server response to a request for the given key
+ """
+ verify_key = signedjson.key.get_verify_key(signing_key)
+ verifykey_id = "%s:%s" % (verify_key.alg, verify_key.version)
- # valid response
response = {
- "server_name": SERVER_NAME,
+ "server_name": server_name,
"old_verify_keys": {},
- "valid_until_ts": VALID_UNTIL_TS,
+ "valid_until_ts": valid_until_ts,
"verify_keys": {
- testverifykey_id: {
- "key": signedjson.key.encode_verify_key_base64(testverifykey)
+ verifykey_id: {
+ "key": signedjson.key.encode_verify_key_base64(verify_key)
}
},
}
-
# the response must be signed by both the origin server and the perspectives
# server.
- signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+ signedjson.sign.sign_json(response, server_name, signing_key)
self.mock_perspective_server.sign_response(response)
+ return response
+
+ def expect_outgoing_key_query(
+ self, expected_server_name: str, expected_key_id: str, response: dict
+ ) -> None:
+ """
+ Tell the mock http client to expect a perspectives-server key query
+ """
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -473,11 +451,79 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# check that the request is for the expected key
q = data["server_keys"]
- self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
+ self.assertEqual(list(q[expected_server_name].keys()), [expected_key_id])
return {"server_keys": [response]}
self.http_client.post_json.side_effect = post_json
+ def test_get_keys_from_perspectives(self):
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
+ SERVER_NAME = "server2"
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 200 * 1000
+
+ response = self.build_perspectives_response(
+ SERVER_NAME, testkey, VALID_UNTIL_TS,
+ )
+
+ self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
+
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+ self.assertIn(SERVER_NAME, keys)
+ k = keys[SERVER_NAME][testverifykey_id]
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
+
+ # check that the perspectives store is correctly updated
+ lookup_triplet = (SERVER_NAME, testverifykey_id, None)
+ key_json = self.get_success(
+ self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+ )
+ res = key_json[lookup_triplet]
+ self.assertEqual(len(res), 1)
+ res = res[0]
+ self.assertEqual(res["key_id"], testverifykey_id)
+ self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
+ self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
+ self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+
+ self.assertEqual(
+ bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
+ )
+
+ def test_get_perspectives_own_key(self):
+ """Check that we can get the perspectives server's own keys
+
+ This is slightly complicated by the fact that the perspectives server may
+ use different keys for signing notary responses.
+ """
+
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
+ SERVER_NAME = self.mock_perspective_server.server_name
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 200 * 1000
+
+ response = self.build_perspectives_response(
+ SERVER_NAME, testkey, VALID_UNTIL_TS
+ )
+
+ self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
+
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys)
@@ -516,35 +562,14 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
VALID_UNTIL_TS = 200 * 1000
def build_response():
- # valid response
- response = {
- "server_name": SERVER_NAME,
- "old_verify_keys": {},
- "valid_until_ts": VALID_UNTIL_TS,
- "verify_keys": {
- testverifykey_id: {
- "key": signedjson.key.encode_verify_key_base64(testverifykey)
- }
- },
- }
-
- # the response must be signed by both the origin server and the perspectives
- # server.
- signedjson.sign.sign_json(response, SERVER_NAME, testkey)
- self.mock_perspective_server.sign_response(response)
- return response
+ return self.build_perspectives_response(
+ SERVER_NAME, testkey, VALID_UNTIL_TS
+ )
def get_key_from_perspectives(response):
fetcher = PerspectivesKeyFetcher(self.hs)
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-
- def post_json(destination, path, data, **kwargs):
- self.assertEqual(destination, self.mock_perspective_server.server_name)
- self.assertEqual(path, "/_matrix/key/v2/query")
- return {"server_keys": [response]}
-
- self.http_client.post_json.side_effect = post_json
-
+ self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
return self.get_success(fetcher.get_keys(keys_to_fetch))
# start with a valid response so we can check we are testing the right thing
@@ -580,7 +605,7 @@ def run_in_context(f, *args, **kwargs):
# logs.
ctx.request = "testctx"
rv = yield f(*args, **kwargs)
- defer.returnValue(rv)
+ return rv
def _verify_json_for_server(kr, *args):
@@ -591,6 +616,6 @@ def _verify_json_for_server(kr, *args):
@defer.inlineCallbacks
def v():
rv1 = yield kr.verify_json_for_server(*args)
- defer.returnValue(rv1)
+ return rv1
return run_in_context(v)
|