diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index f9ce609923..8ff1460c0d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -23,6 +23,7 @@ from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
+from twisted.internet.defer import Deferred, ensureDeferred
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
@@ -33,16 +34,17 @@ from synapse.crypto.keyring import (
)
from synapse.logging.context import (
LoggingContext,
- PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
from synapse.storage.keys import FetchKeyResult
from tests import unittest
+from tests.test_utils import make_awaitable
+from tests.unittest import logcontext_clean
-class MockPerspectiveServer(object):
+class MockPerspectiveServer:
def __init__(self):
self.server_name = "mock_server"
self.key = signedjson.key.generate_signing_key(0)
@@ -66,56 +68,42 @@ class MockPerspectiveServer(object):
signedjson.sign.sign_json(res, self.server_name, self.key)
+@logcontext_clean
class KeyringTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- self.mock_perspective_server = MockPerspectiveServer()
- self.http_client = Mock()
-
- config = self.default_config()
- config["trusted_key_servers"] = [
- {
- "server_name": self.mock_perspective_server.server_name,
- "verify_keys": self.mock_perspective_server.get_verify_keys(),
- }
- ]
-
- return self.setup_test_homeserver(
- handlers=None, http_client=self.http_client, config=config
- )
-
- def check_context(self, _, expected):
+ def check_context(self, val, expected):
self.assertEquals(getattr(current_context(), "request", None), expected)
+ return val
def test_verify_json_objects_for_server_awaits_previous_requests(self):
- key1 = signedjson.key.generate_signing_key(1)
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock()
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
- kr = keyring.Keyring(self.hs)
+ # a signed object that we are going to try to validate
+ key1 = signedjson.key.generate_signing_key(1)
json1 = {}
signedjson.sign.sign_json(json1, "server10", key1)
- persp_resp = {
- "server_keys": [
- self.mock_perspective_server.get_signed_key(
- "server10", signedjson.key.get_verify_key(key1)
- )
- ]
- }
- persp_deferred = defer.Deferred()
+ # start off a first set of lookups. We make the mock fetcher block until this
+ # deferred completes.
+ first_lookup_deferred = Deferred()
- @defer.inlineCallbacks
- def get_perspectives(**kwargs):
- self.assertEquals(current_context().request, "11")
- with PreserveLoggingContext():
- yield persp_deferred
- return persp_resp
+ async def first_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_11")
+ self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
- self.http_client.post_json.side_effect = get_perspectives
+ await make_deferred_yieldable(first_lookup_deferred)
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
+
+ mock_fetcher.get_keys.side_effect = first_lookup_fetch
- # start off a first set of lookups
- @defer.inlineCallbacks
- def first_lookup():
- with LoggingContext("11") as context_11:
- context_11.request = "11"
+ async def first_lookup():
+ with LoggingContext("context_11") as context_11:
+ context_11.request = "context_11"
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
@@ -124,7 +112,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# the unsigned json should be rejected pretty quickly
self.assertTrue(res_deferreds[1].called)
try:
- yield res_deferreds[1]
+ await res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure")
except SynapseError:
pass
@@ -132,45 +120,51 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds[0])
+ await make_deferred_yieldable(res_deferreds[0])
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
+ d0 = ensureDeferred(first_lookup())
- d0 = first_lookup()
-
- # wait a tick for it to send the request to the perspectives server
- # (it first tries the datastore)
- self.pump()
- self.http_client.post_json.assert_called_once()
+ mock_fetcher.get_keys.assert_called_once()
# a second request for a server with outstanding requests
# should block rather than start a second call
- @defer.inlineCallbacks
- def second_lookup():
- with LoggingContext("12") as context_12:
- context_12.request = "12"
- self.http_client.post_json.reset_mock()
- self.http_client.post_json.return_value = defer.Deferred()
+
+ async def second_lookup_fetch(keys_to_fetch):
+ self.assertEquals(current_context().request, "context_12")
+ return {
+ "server10": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
+ }
+ }
+
+ mock_fetcher.get_keys.reset_mock()
+ mock_fetcher.get_keys.side_effect = second_lookup_fetch
+ second_lookup_state = [0]
+
+ async def second_lookup():
+ with LoggingContext("context_12") as context_12:
+ context_12.request = "context_12"
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")]
)
res_deferreds_2[0].addBoth(self.check_context, None)
- yield make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 1
+ await make_deferred_yieldable(res_deferreds_2[0])
+ second_lookup_state[0] = 2
- # let verify_json_objects_for_server finish its work before we kill the
- # logcontext
- yield self.clock.sleep(0)
-
- d2 = second_lookup()
+ d2 = ensureDeferred(second_lookup())
self.pump()
- self.http_client.post_json.assert_not_called()
+ # the second request should be pending, but the fetcher should not yet have been
+ # called
+ self.assertEqual(second_lookup_state[0], 1)
+ mock_fetcher.get_keys.assert_not_called()
# complete the first request
- persp_deferred.callback(persp_resp)
+ first_lookup_deferred.callback(None)
+
+ # and now both verifications should succeed.
self.get_success(d0)
self.get_success(d2)
@@ -190,7 +184,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
- self.failureResultOf(d, SynapseError)
+ self.get_failure(d, SynapseError)
# should succeed on a signed object
d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
@@ -202,7 +196,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
with a null `ts_valid_until_ms`
"""
mock_fetcher = keyring.KeyFetcher()
- mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
+ mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
kr = keyring.Keyring(
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
@@ -221,7 +215,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
- self.failureResultOf(d, SynapseError)
+ self.get_failure(d, SynapseError)
# should fail on a signed object with a non-zero minimum_valid_until_ms,
# as it tries to refetch the keys and fails.
@@ -245,17 +239,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key(1)
- def get_keys(keys_to_fetch):
+ async def get_keys(keys_to_fetch):
# there should only be one request object (with the max validity)
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
- }
+ return {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
- )
+ }
mock_fetcher = keyring.KeyFetcher()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
@@ -282,25 +274,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""If the first fetcher cannot provide a recent enough key, we fall back"""
key1 = signedjson.key.generate_signing_key(1)
- def get_keys1(keys_to_fetch):
+ async def get_keys1(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
- }
- }
- )
+ return {
+ "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
+ }
- def get_keys2(keys_to_fetch):
+ async def get_keys2(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
- }
+ return {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
- )
+ }
mock_fetcher1 = keyring.KeyFetcher()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
@@ -325,6 +311,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher2.get_keys.assert_called_once()
+@logcontext_clean
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
@@ -355,7 +342,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
}
signedjson.sign.sign_json(response, SERVER_NAME, testkey)
- def get_json(destination, path, **kwargs):
+ async def get_json(destination, path, **kwargs):
self.assertEqual(destination, SERVER_NAME)
self.assertEqual(path, "/_matrix/key/v2/server/key1")
return response
@@ -444,7 +431,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect a perspectives-server key query
"""
- def post_json(destination, path, data, **kwargs):
+ async def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
self.assertEqual(path, "/_matrix/key/v2/query")
@@ -580,14 +567,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# remove the perspectives server's signature
response = build_response()
del response["signatures"][self.mock_perspective_server.server_name]
- self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
# remove the origin server's signature
response = build_response()
del response["signatures"][SERVER_NAME]
- self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
|