diff options
-rw-r--r-- | tests/crypto/test_keyring.py | 84 |
1 files changed, 83 insertions, 1 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 1b9696748f..f414418adb 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -25,10 +25,12 @@ from signedjson.types import SigningKey, VerifyKey from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import NoResource, Resource -from synapse.api.errors import SynapseError +from synapse.api.errors import HttpResponseException, SynapseError from synapse.crypto import keyring from synapse.crypto.keyring import ( + InternalWorkerRequestKeyFetcher, PerspectivesKeyFetcher, ServerKeyFetcher, StoreKeyFetcher, @@ -39,12 +41,15 @@ from synapse.logging.context import ( current_context, make_deferred_yieldable, ) +from synapse.rest.key.v2 import KeyResource from synapse.server import HomeServer from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict from synapse.util import Clock +from synapse.util.httpresourcetree import create_resource_tree from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.test_utils import make_awaitable from tests.unittest import logcontext_clean, override_config @@ -757,6 +762,83 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") +class InternalWorkerRequestKeyFetcherTestCase(BaseMultiWorkerStreamTestCase): + def create_test_resource(self) -> Resource: # type: ignore[override] + return create_resource_tree( + {"/_matrix/key/v2": KeyResource(self.hs)}, root_resource=NoResource() + ) + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + config.update( + federation_sender_instances=["federation_sender1"], + instance_map={ + "federation_sender1": {"host": "testserv", "port": 1001}, + }, + ) + return config + + def test_key_fetching_works_across_workers(self) -> None: + """Test that a non-fed-sender worker requests keys via a fed-sender.""" + mock_http_client = Mock() + + # 1. Mock out the response from the notary server. + async def mock_post_json(*args: Any, **kwargs: Any) -> JsonDict: + """Mock the request to the notary server.""" + if kwargs.get("path") != "/_matrix/key/v2/query": + raise HttpResponseException(500, "ruh", b"roh") + return {"server_keys": []} + + mock_http_client.post_json = mock_post_json + + # 2. Build a valid response to /_matrix/key/v2/server for the server being + # queried. + SERVER_NAME = "server2" + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 200 * 1000 + response = { + "server_name": SERVER_NAME, + "old_verify_keys": {}, + "valid_until_ts": VALID_UNTIL_TS, + "verify_keys": { + testverifykey_id: { + "key": signedjson.key.encode_verify_key_base64(testverifykey) + } + }, + } + signedjson.sign.sign_json(response, SERVER_NAME, testkey) + + async def mock_get_json(*args: Any, **kwargs: Any) -> JsonDict: + if kwargs.get("path") != "/_matrix/key/v2/server": + raise HttpResponseException(500, "ruh", b"roh") + return response + + mock_http_client.get_json = mock_get_json + + # 3. Make a federation homeserver to actually make the request. + self.make_worker_hs( + "synapse.app.generic_worker", + { + "worker_name": "federation_sender1", + "federation_sender_instances": ["federation_sender1"], + }, + federation_http_client=mock_http_client, + ) + + # 4. Use the via-fed-sender fetcher to get keys. + fetcher = InternalWorkerRequestKeyFetcher(self.hs) + keys = self.get_success( + fetcher.get_keys(SERVER_NAME, [testverifykey_id], 0), by=0.1 + ) + k = keys[testverifykey_id] + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") + + def get_key_id(key: SigningKey) -> str: """Get the matrix ID tag for a given SigningKey or VerifyKey""" return "%s:%s" % (key.alg, key.version) |