diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 8efd39c7f7..34d5895f18 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,6 +19,7 @@ from mock import Mock
import canonicaljson
import signedjson.key
import signedjson.sign
+from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
from twisted.internet import defer
@@ -412,34 +413,37 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
handlers=None, http_client=self.http_client, config=config
)
- def test_get_keys_from_perspectives(self):
- # arbitrarily advance the clock a bit
- self.reactor.advance(100)
-
- fetcher = PerspectivesKeyFetcher(self.hs)
-
- SERVER_NAME = "server2"
- testkey = signedjson.key.generate_signing_key("ver1")
- testverifykey = signedjson.key.get_verify_key(testkey)
- testverifykey_id = "ed25519:ver1"
- VALID_UNTIL_TS = 200 * 1000
+ def build_perspectives_response(
+ self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
+ ) -> dict:
+ """
+ Build a valid perspectives server response to a request for the given key
+ """
+ verify_key = signedjson.key.get_verify_key(signing_key)
+ verifykey_id = "%s:%s" % (verify_key.alg, verify_key.version)
- # valid response
response = {
- "server_name": SERVER_NAME,
+ "server_name": server_name,
"old_verify_keys": {},
- "valid_until_ts": VALID_UNTIL_TS,
+ "valid_until_ts": valid_until_ts,
"verify_keys": {
- testverifykey_id: {
- "key": signedjson.key.encode_verify_key_base64(testverifykey)
+ verifykey_id: {
+ "key": signedjson.key.encode_verify_key_base64(verify_key)
}
},
}
-
# the response must be signed by both the origin server and the perspectives
# server.
- signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+ signedjson.sign.sign_json(response, server_name, signing_key)
self.mock_perspective_server.sign_response(response)
+ return response
+
+ def expect_outgoing_key_query(
+ self, expected_server_name: str, expected_key_id: str, response: dict
+ ) -> None:
+ """
+ Tell the mock http client to expect a perspectives-server key query
+ """
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -447,11 +451,79 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# check that the request is for the expected key
q = data["server_keys"]
- self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
+ self.assertEqual(list(q[expected_server_name].keys()), [expected_key_id])
return {"server_keys": [response]}
self.http_client.post_json.side_effect = post_json
+ def test_get_keys_from_perspectives(self):
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
+ SERVER_NAME = "server2"
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 200 * 1000
+
+ response = self.build_perspectives_response(
+ SERVER_NAME, testkey, VALID_UNTIL_TS,
+ )
+
+ self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
+
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+ self.assertIn(SERVER_NAME, keys)
+ k = keys[SERVER_NAME][testverifykey_id]
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
+
+ # check that the perspectives store is correctly updated
+ lookup_triplet = (SERVER_NAME, testverifykey_id, None)
+ key_json = self.get_success(
+ self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+ )
+ res = key_json[lookup_triplet]
+ self.assertEqual(len(res), 1)
+ res = res[0]
+ self.assertEqual(res["key_id"], testverifykey_id)
+ self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
+ self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
+ self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+
+ self.assertEqual(
+ bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
+ )
+
+ def test_get_perspectives_own_key(self):
+ """Check that we can get the perspectives server's own keys
+
+ This is slightly complicated by the fact that the perspectives server may
+ use different keys for signing notary responses.
+ """
+
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
+ SERVER_NAME = self.mock_perspective_server.server_name
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 200 * 1000
+
+ response = self.build_perspectives_response(
+ SERVER_NAME, testkey, VALID_UNTIL_TS
+ )
+
+ self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
+
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys)
@@ -490,35 +562,14 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
VALID_UNTIL_TS = 200 * 1000
def build_response():
- # valid response
- response = {
- "server_name": SERVER_NAME,
- "old_verify_keys": {},
- "valid_until_ts": VALID_UNTIL_TS,
- "verify_keys": {
- testverifykey_id: {
- "key": signedjson.key.encode_verify_key_base64(testverifykey)
- }
- },
- }
-
- # the response must be signed by both the origin server and the perspectives
- # server.
- signedjson.sign.sign_json(response, SERVER_NAME, testkey)
- self.mock_perspective_server.sign_response(response)
- return response
+ return self.build_perspectives_response(
+ SERVER_NAME, testkey, VALID_UNTIL_TS
+ )
def get_key_from_perspectives(response):
fetcher = PerspectivesKeyFetcher(self.hs)
keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-
- def post_json(destination, path, data, **kwargs):
- self.assertEqual(destination, self.mock_perspective_server.server_name)
- self.assertEqual(path, "/_matrix/key/v2/query")
- return {"server_keys": [response]}
-
- self.http_client.post_json.side_effect = post_json
-
+ self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
return self.get_success(fetcher.get_keys(keys_to_fetch))
# start with a valid response so we can check we are testing the right thing
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
new file mode 100644
index 0000000000..d8246b4e78
--- /dev/null
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import urllib.parse
+from io import BytesIO
+
+from mock import Mock
+
+import signedjson.key
+from nacl.signing import SigningKey
+from signedjson.sign import sign_json
+
+from twisted.web.resource import NoResource
+
+from synapse.http.site import SynapseRequest
+from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.util.httpresourcetree import create_resource_tree
+
+from tests import unittest
+from tests.server import FakeChannel, wait_until_result
+
+
+class RemoteKeyResourceTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.http_client = Mock()
+ return self.setup_test_homeserver(http_client=self.http_client)
+
+ def create_test_json_resource(self):
+ return create_resource_tree(
+ {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
+ )
+
+ def expect_outgoing_key_request(
+ self, server_name: str, signing_key: SigningKey
+ ) -> None:
+ """
+ Tell the mock http client to expect an outgoing GET request for the given key
+ """
+
+ def get_json(destination, path, ignore_backoff=False, **kwargs):
+ self.assertTrue(ignore_backoff)
+ self.assertEqual(destination, server_name)
+ key_id = "%s:%s" % (signing_key.alg, signing_key.version)
+ self.assertEqual(
+ path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),)
+ )
+
+ response = {
+ "server_name": server_name,
+ "old_verify_keys": {},
+ "valid_until_ts": 200 * 1000,
+ "verify_keys": {
+ key_id: {
+ "key": signedjson.key.encode_verify_key_base64(
+ signing_key.verify_key
+ )
+ }
+ },
+ }
+ sign_json(response, server_name, signing_key)
+ return response
+
+ self.http_client.get_json.side_effect = get_json
+
+ def make_notary_request(self, server_name: str, key_id: str) -> dict:
+ """Send a GET request to the test server requesting the given key.
+
+ Checks that the response is a 200 and returns the decoded json body.
+ """
+ channel = FakeChannel(self.site, self.reactor)
+ req = SynapseRequest(channel)
+ req.content = BytesIO(b"")
+ req.requestReceived(
+ b"GET",
+ b"/_matrix/key/v2/query/%s/%s"
+ % (server_name.encode("utf-8"), key_id.encode("utf-8")),
+ b"1.1",
+ )
+ wait_until_result(self.reactor, req)
+ self.assertEqual(channel.code, 200)
+ resp = channel.json_body
+ return resp
+
+ def test_get_key(self):
+ """Fetch a remote key"""
+ SERVER_NAME = "remote.server"
+ testkey = signedjson.key.generate_signing_key("ver1")
+ self.expect_outgoing_key_request(SERVER_NAME, testkey)
+
+ resp = self.make_notary_request(SERVER_NAME, "ed25519:ver1")
+ keys = resp["server_keys"]
+ self.assertEqual(len(keys), 1)
+
+ self.assertIn("ed25519:ver1", keys[0]["verify_keys"])
+ self.assertEqual(len(keys[0]["verify_keys"]), 1)
+
+ # it should be signed by both the origin server and the notary
+ self.assertIn(SERVER_NAME, keys[0]["signatures"])
+ self.assertIn(self.hs.hostname, keys[0]["signatures"])
+
+ def test_get_own_key(self):
+ """Fetch our own key"""
+ testkey = signedjson.key.generate_signing_key("ver1")
+ self.expect_outgoing_key_request(self.hs.hostname, testkey)
+
+ resp = self.make_notary_request(self.hs.hostname, "ed25519:ver1")
+ keys = resp["server_keys"]
+ self.assertEqual(len(keys), 1)
+
+ # it should be signed by both itself, and the notary signing key
+ sigs = keys[0]["signatures"]
+ self.assertEqual(len(sigs), 1)
+ self.assertIn(self.hs.hostname, sigs)
+ oursigs = sigs[self.hs.hostname]
+ self.assertEqual(len(oursigs), 2)
+
+ # and both keys should be present in the verify_keys section
+ self.assertIn("ed25519:ver1", keys[0]["verify_keys"])
+ self.assertIn("ed25519:a_lPym", keys[0]["verify_keys"])
diff --git a/tests/server.py b/tests/server.py
index a554dfdd57..1644710aa0 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -20,6 +20,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http import unquote
from twisted.web.http_headers import Headers
+from twisted.web.server import Site
from synapse.http.site import SynapseRequest
from synapse.util import Clock
@@ -42,6 +43,7 @@ class FakeChannel(object):
wire).
"""
+ site = attr.ib(type=Site)
_reactor = attr.ib()
result = attr.ib(default=attr.Factory(dict))
_producer = None
@@ -176,9 +178,9 @@ def make_request(
content = content.encode("utf8")
site = FakeSite()
- channel = FakeChannel(reactor)
+ channel = FakeChannel(site, reactor)
- req = request(site, channel)
+ req = request(channel)
req.process = lambda: b""
req.content = BytesIO(content)
req.postpath = list(map(unquote, path[1:].split(b"/")))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index aec76f4ab1..ae14fb407d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -2,44 +2,37 @@ from mock import Mock
from twisted.internet import defer
+from synapse.storage.background_updates import BackgroundUpdater
+
from tests import unittest
-from tests.utils import setup_test_homeserver
-class BackgroundUpdateTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
- self.store = hs.get_datastore()
- self.clock = hs.get_clock()
+class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater
+ # the base test class should have run the real bg updates for us
+ self.assertTrue(self.updates.has_completed_background_updates())
self.update_handler = Mock()
-
- yield self.store.db.updates.register_background_update_handler(
+ self.updates.register_background_update_handler(
"test_update", self.update_handler
)
- # run the real background updates, to get them out the way
- # (perhaps we should run them as part of the test HS setup, since we
- # run all of the other schema setup stuff there?)
- while True:
- res = yield self.store.db.updates.do_next_background_update(1000)
- if res is None:
- break
-
- @defer.inlineCallbacks
def test_do_background_update(self):
- desired_count = 1000
+ # the time we claim each update takes
duration_ms = 42
+ # the target runtime for each bg update
+ target_background_update_duration_ms = 50000
+
# first step: make a bit of progress
@defer.inlineCallbacks
def update(progress, count):
- self.clock.advance_time_msec(count * duration_ms)
+ yield self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
- yield self.store.db.runInteraction(
+ yield self.hs.get_datastore().db.runInteraction(
"update_progress",
- self.store.db.updates._background_update_progress_txn,
+ self.updates._background_update_progress_txn,
"test_update",
progress,
)
@@ -47,37 +40,46 @@ class BackgroundUpdateTestCase(unittest.TestCase):
self.update_handler.side_effect = update
- yield self.store.db.updates.start_background_update(
- "test_update", {"my_key": 1}
+ self.get_success(
+ self.updates.start_background_update("test_update", {"my_key": 1})
)
-
self.update_handler.reset_mock()
- result = yield self.store.db.updates.do_next_background_update(
- duration_ms * desired_count
+ res = self.get_success(
+ self.updates.do_next_background_update(
+ target_background_update_duration_ms
+ ),
+ by=0.1,
)
- self.assertIsNotNone(result)
+ self.assertIsNotNone(res)
+
+ # on the first call, we should get run with the default background update size
self.update_handler.assert_called_once_with(
- {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE
+ {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE
)
# second step: complete the update
+ # we should now get run with a much bigger number of items to update
@defer.inlineCallbacks
def update(progress, count):
- yield self.store.db.updates._end_background_update("test_update")
+ self.assertEqual(progress, {"my_key": 2})
+ self.assertAlmostEqual(
+ count, target_background_update_duration_ms / duration_ms, places=0,
+ )
+ yield self.updates._end_background_update("test_update")
return count
self.update_handler.side_effect = update
self.update_handler.reset_mock()
- result = yield self.store.db.updates.do_next_background_update(
- duration_ms * desired_count
+ result = self.get_success(
+ self.updates.do_next_background_update(target_background_update_duration_ms)
)
self.assertIsNotNone(result)
- self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
+ self.update_handler.assert_called_once()
# third step: we don't expect to be called any more
self.update_handler.reset_mock()
- result = yield self.store.db.updates.do_next_background_update(
- duration_ms * desired_count
+ result = self.get_success(
+ self.updates.do_next_background_update(target_background_update_duration_ms)
)
self.assertIsNone(result)
self.assertFalse(self.update_handler.called)
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index eadfb90a22..a331517f4d 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -60,21 +60,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
(event_id, bytearray(b"ffff")),
)
- for i in range(0, 11):
+ for i in range(0, 20):
yield self.store.db.runInteraction("insert", insert_event, i)
- # this should get the last five and five others
+ # this should get the last ten
r = yield self.store.get_prev_events_for_room(room_id)
self.assertEqual(10, len(r))
- for i in range(0, 5):
- el = r[i]
- depth = el[2]
- self.assertEqual(10 - i, depth)
-
- for i in range(5, 5):
- el = r[i]
- depth = el[2]
- self.assertLessEqual(5, depth)
+ for i in range(0, 10):
+ self.assertEqual("$event_%i:local" % (19 - i), r[i])
@defer.inlineCallbacks
def test_get_rooms_with_many_extremities(self):
diff --git a/tests/unittest.py b/tests/unittest.py
index b30b7d1718..ddcd4becfe 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -36,7 +36,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
+from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import LoggingContext
from synapse.server import HomeServer
from synapse.types import Requester, UserID, create_requester
@@ -210,6 +210,15 @@ class HomeserverTestCase(TestCase):
# Register the resources
self.resource = self.create_test_json_resource()
+ # create a site to wrap the resource.
+ self.site = SynapseSite(
+ logger_name="synapse.access.http.fake",
+ site_tag="test",
+ config={},
+ resource=self.resource,
+ server_version_string="1",
+ )
+
from tests.rest.client.v1.utils import RestHelper
self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
@@ -522,10 +531,6 @@ class HomeserverTestCase(TestCase):
secrets = self.hs.get_secrets()
requester = Requester(user, None, False, None, None)
- prev_events_and_hashes = None
- if prev_event_ids:
- prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids]
-
event, context = self.get_success(
event_creator.create_event(
requester,
@@ -535,7 +540,7 @@ class HomeserverTestCase(TestCase):
"sender": user.to_string(),
"content": {"body": secrets.token_hex(), "msgtype": "m.text"},
},
- prev_events_and_hashes=prev_events_and_hashes,
+ prev_event_ids=prev_event_ids,
)
)
|