summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/crypto/test_keyring.py139
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py130
-rw-r--r--tests/server.py6
-rw-r--r--tests/storage/test_background_update.py72
-rw-r--r--tests/storage/test_event_federation.py15
-rw-r--r--tests/unittest.py17
6 files changed, 281 insertions, 98 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 8efd39c7f7..34d5895f18 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,6 +19,7 @@ from mock import Mock
 import canonicaljson
 import signedjson.key
 import signedjson.sign
+from nacl.signing import SigningKey
 from signedjson.key import encode_verify_key_base64, get_verify_key
 
 from twisted.internet import defer
@@ -412,34 +413,37 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
             handlers=None, http_client=self.http_client, config=config
         )
 
-    def test_get_keys_from_perspectives(self):
-        # arbitrarily advance the clock a bit
-        self.reactor.advance(100)
-
-        fetcher = PerspectivesKeyFetcher(self.hs)
-
-        SERVER_NAME = "server2"
-        testkey = signedjson.key.generate_signing_key("ver1")
-        testverifykey = signedjson.key.get_verify_key(testkey)
-        testverifykey_id = "ed25519:ver1"
-        VALID_UNTIL_TS = 200 * 1000
+    def build_perspectives_response(
+        self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
+    ) -> dict:
+        """
+        Build a valid perspectives server response to a request for the given key
+        """
+        verify_key = signedjson.key.get_verify_key(signing_key)
+        verifykey_id = "%s:%s" % (verify_key.alg, verify_key.version)
 
-        # valid response
         response = {
-            "server_name": SERVER_NAME,
+            "server_name": server_name,
             "old_verify_keys": {},
-            "valid_until_ts": VALID_UNTIL_TS,
+            "valid_until_ts": valid_until_ts,
             "verify_keys": {
-                testverifykey_id: {
-                    "key": signedjson.key.encode_verify_key_base64(testverifykey)
+                verifykey_id: {
+                    "key": signedjson.key.encode_verify_key_base64(verify_key)
                 }
             },
         }
-
         # the response must be signed by both the origin server and the perspectives
         # server.
-        signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+        signedjson.sign.sign_json(response, server_name, signing_key)
         self.mock_perspective_server.sign_response(response)
+        return response
+
+    def expect_outgoing_key_query(
+        self, expected_server_name: str, expected_key_id: str, response: dict
+    ) -> None:
+        """
+        Tell the mock http client to expect a perspectives-server key query
+        """
 
         def post_json(destination, path, data, **kwargs):
             self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -447,11 +451,79 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
             # check that the request is for the expected key
             q = data["server_keys"]
-            self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
+            self.assertEqual(list(q[expected_server_name].keys()), [expected_key_id])
             return {"server_keys": [response]}
 
         self.http_client.post_json.side_effect = post_json
 
+    def test_get_keys_from_perspectives(self):
+        # arbitrarily advance the clock a bit
+        self.reactor.advance(100)
+
+        fetcher = PerspectivesKeyFetcher(self.hs)
+
+        SERVER_NAME = "server2"
+        testkey = signedjson.key.generate_signing_key("ver1")
+        testverifykey = signedjson.key.get_verify_key(testkey)
+        testverifykey_id = "ed25519:ver1"
+        VALID_UNTIL_TS = 200 * 1000
+
+        response = self.build_perspectives_response(
+            SERVER_NAME, testkey, VALID_UNTIL_TS,
+        )
+
+        self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
+
+        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+        self.assertIn(SERVER_NAME, keys)
+        k = keys[SERVER_NAME][testverifykey_id]
+        self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+        self.assertEqual(k.verify_key, testverifykey)
+        self.assertEqual(k.verify_key.alg, "ed25519")
+        self.assertEqual(k.verify_key.version, "ver1")
+
+        # check that the perspectives store is correctly updated
+        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
+        key_json = self.get_success(
+            self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+        )
+        res = key_json[lookup_triplet]
+        self.assertEqual(len(res), 1)
+        res = res[0]
+        self.assertEqual(res["key_id"], testverifykey_id)
+        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
+        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
+        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+
+        self.assertEqual(
+            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
+        )
+
+    def test_get_perspectives_own_key(self):
+        """Check that we can get the perspectives server's own keys
+
+        This is slightly complicated by the fact that the perspectives server may
+        use different keys for signing notary responses.
+        """
+
+        # arbitrarily advance the clock a bit
+        self.reactor.advance(100)
+
+        fetcher = PerspectivesKeyFetcher(self.hs)
+
+        SERVER_NAME = self.mock_perspective_server.server_name
+        testkey = signedjson.key.generate_signing_key("ver1")
+        testverifykey = signedjson.key.get_verify_key(testkey)
+        testverifykey_id = "ed25519:ver1"
+        VALID_UNTIL_TS = 200 * 1000
+
+        response = self.build_perspectives_response(
+            SERVER_NAME, testkey, VALID_UNTIL_TS
+        )
+
+        self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
+
         keys_to_fetch = {SERVER_NAME: {"key1": 0}}
         keys = self.get_success(fetcher.get_keys(keys_to_fetch))
         self.assertIn(SERVER_NAME, keys)
@@ -490,35 +562,14 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         VALID_UNTIL_TS = 200 * 1000
 
         def build_response():
-            # valid response
-            response = {
-                "server_name": SERVER_NAME,
-                "old_verify_keys": {},
-                "valid_until_ts": VALID_UNTIL_TS,
-                "verify_keys": {
-                    testverifykey_id: {
-                        "key": signedjson.key.encode_verify_key_base64(testverifykey)
-                    }
-                },
-            }
-
-            # the response must be signed by both the origin server and the perspectives
-            # server.
-            signedjson.sign.sign_json(response, SERVER_NAME, testkey)
-            self.mock_perspective_server.sign_response(response)
-            return response
+            return self.build_perspectives_response(
+                SERVER_NAME, testkey, VALID_UNTIL_TS
+            )
 
         def get_key_from_perspectives(response):
             fetcher = PerspectivesKeyFetcher(self.hs)
             keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-
-            def post_json(destination, path, data, **kwargs):
-                self.assertEqual(destination, self.mock_perspective_server.server_name)
-                self.assertEqual(path, "/_matrix/key/v2/query")
-                return {"server_keys": [response]}
-
-            self.http_client.post_json.side_effect = post_json
-
+            self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
             return self.get_success(fetcher.get_keys(keys_to_fetch))
 
         # start with a valid response so we can check we are testing the right thing
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
new file mode 100644
index 0000000000..d8246b4e78
--- /dev/null
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import urllib.parse
+from io import BytesIO
+
+from mock import Mock
+
+import signedjson.key
+from nacl.signing import SigningKey
+from signedjson.sign import sign_json
+
+from twisted.web.resource import NoResource
+
+from synapse.http.site import SynapseRequest
+from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.util.httpresourcetree import create_resource_tree
+
+from tests import unittest
+from tests.server import FakeChannel, wait_until_result
+
+
+class RemoteKeyResourceTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        self.http_client = Mock()
+        return self.setup_test_homeserver(http_client=self.http_client)
+
+    def create_test_json_resource(self):
+        return create_resource_tree(
+            {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
+        )
+
+    def expect_outgoing_key_request(
+        self, server_name: str, signing_key: SigningKey
+    ) -> None:
+        """
+        Tell the mock http client to expect an outgoing GET request for the given key
+        """
+
+        def get_json(destination, path, ignore_backoff=False, **kwargs):
+            self.assertTrue(ignore_backoff)
+            self.assertEqual(destination, server_name)
+            key_id = "%s:%s" % (signing_key.alg, signing_key.version)
+            self.assertEqual(
+                path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),)
+            )
+
+            response = {
+                "server_name": server_name,
+                "old_verify_keys": {},
+                "valid_until_ts": 200 * 1000,
+                "verify_keys": {
+                    key_id: {
+                        "key": signedjson.key.encode_verify_key_base64(
+                            signing_key.verify_key
+                        )
+                    }
+                },
+            }
+            sign_json(response, server_name, signing_key)
+            return response
+
+        self.http_client.get_json.side_effect = get_json
+
+    def make_notary_request(self, server_name: str, key_id: str) -> dict:
+        """Send a GET request to the test server requesting the given key.
+
+        Checks that the response is a 200 and returns the decoded json body.
+        """
+        channel = FakeChannel(self.site, self.reactor)
+        req = SynapseRequest(channel)
+        req.content = BytesIO(b"")
+        req.requestReceived(
+            b"GET",
+            b"/_matrix/key/v2/query/%s/%s"
+            % (server_name.encode("utf-8"), key_id.encode("utf-8")),
+            b"1.1",
+        )
+        wait_until_result(self.reactor, req)
+        self.assertEqual(channel.code, 200)
+        resp = channel.json_body
+        return resp
+
+    def test_get_key(self):
+        """Fetch a remote key"""
+        SERVER_NAME = "remote.server"
+        testkey = signedjson.key.generate_signing_key("ver1")
+        self.expect_outgoing_key_request(SERVER_NAME, testkey)
+
+        resp = self.make_notary_request(SERVER_NAME, "ed25519:ver1")
+        keys = resp["server_keys"]
+        self.assertEqual(len(keys), 1)
+
+        self.assertIn("ed25519:ver1", keys[0]["verify_keys"])
+        self.assertEqual(len(keys[0]["verify_keys"]), 1)
+
+        # it should be signed by both the origin server and the notary
+        self.assertIn(SERVER_NAME, keys[0]["signatures"])
+        self.assertIn(self.hs.hostname, keys[0]["signatures"])
+
+    def test_get_own_key(self):
+        """Fetch our own key"""
+        testkey = signedjson.key.generate_signing_key("ver1")
+        self.expect_outgoing_key_request(self.hs.hostname, testkey)
+
+        resp = self.make_notary_request(self.hs.hostname, "ed25519:ver1")
+        keys = resp["server_keys"]
+        self.assertEqual(len(keys), 1)
+
+        # it should be signed by both itself, and the notary signing key
+        sigs = keys[0]["signatures"]
+        self.assertEqual(len(sigs), 1)
+        self.assertIn(self.hs.hostname, sigs)
+        oursigs = sigs[self.hs.hostname]
+        self.assertEqual(len(oursigs), 2)
+
+        # and both keys should be present in the verify_keys section
+        self.assertIn("ed25519:ver1", keys[0]["verify_keys"])
+        self.assertIn("ed25519:a_lPym", keys[0]["verify_keys"])
diff --git a/tests/server.py b/tests/server.py
index a554dfdd57..1644710aa0 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -20,6 +20,7 @@ from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 from twisted.web.http import unquote
 from twisted.web.http_headers import Headers
+from twisted.web.server import Site
 
 from synapse.http.site import SynapseRequest
 from synapse.util import Clock
@@ -42,6 +43,7 @@ class FakeChannel(object):
     wire).
     """
 
+    site = attr.ib(type=Site)
     _reactor = attr.ib()
     result = attr.ib(default=attr.Factory(dict))
     _producer = None
@@ -176,9 +178,9 @@ def make_request(
         content = content.encode("utf8")
 
     site = FakeSite()
-    channel = FakeChannel(reactor)
+    channel = FakeChannel(site, reactor)
 
-    req = request(site, channel)
+    req = request(channel)
     req.process = lambda: b""
     req.content = BytesIO(content)
     req.postpath = list(map(unquote, path[1:].split(b"/")))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index aec76f4ab1..ae14fb407d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -2,44 +2,37 @@ from mock import Mock
 
 from twisted.internet import defer
 
+from synapse.storage.background_updates import BackgroundUpdater
+
 from tests import unittest
-from tests.utils import setup_test_homeserver
 
 
-class BackgroundUpdateTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
+class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, homeserver):
+        self.updates = self.hs.get_datastore().db.updates  # type: BackgroundUpdater
+        # the base test class should have run the real bg updates for us
+        self.assertTrue(self.updates.has_completed_background_updates())
 
         self.update_handler = Mock()
-
-        yield self.store.db.updates.register_background_update_handler(
+        self.updates.register_background_update_handler(
             "test_update", self.update_handler
         )
 
-        # run the real background updates, to get them out the way
-        # (perhaps we should run them as part of the test HS setup, since we
-        # run all of the other schema setup stuff there?)
-        while True:
-            res = yield self.store.db.updates.do_next_background_update(1000)
-            if res is None:
-                break
-
-    @defer.inlineCallbacks
     def test_do_background_update(self):
-        desired_count = 1000
+        # the time we claim each update takes
         duration_ms = 42
 
+        # the target runtime for each bg update
+        target_background_update_duration_ms = 50000
+
         # first step: make a bit of progress
         @defer.inlineCallbacks
         def update(progress, count):
-            self.clock.advance_time_msec(count * duration_ms)
+            yield self.clock.sleep((count * duration_ms) / 1000)
             progress = {"my_key": progress["my_key"] + 1}
-            yield self.store.db.runInteraction(
+            yield self.hs.get_datastore().db.runInteraction(
                 "update_progress",
-                self.store.db.updates._background_update_progress_txn,
+                self.updates._background_update_progress_txn,
                 "test_update",
                 progress,
             )
@@ -47,37 +40,46 @@ class BackgroundUpdateTestCase(unittest.TestCase):
 
         self.update_handler.side_effect = update
 
-        yield self.store.db.updates.start_background_update(
-            "test_update", {"my_key": 1}
+        self.get_success(
+            self.updates.start_background_update("test_update", {"my_key": 1})
         )
-
         self.update_handler.reset_mock()
-        result = yield self.store.db.updates.do_next_background_update(
-            duration_ms * desired_count
+        res = self.get_success(
+            self.updates.do_next_background_update(
+                target_background_update_duration_ms
+            ),
+            by=0.1,
         )
-        self.assertIsNotNone(result)
+        self.assertIsNotNone(res)
+
+        # on the first call, we should get run with the default background update size
         self.update_handler.assert_called_once_with(
-            {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE
+            {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE
         )
 
         # second step: complete the update
+        # we should now get run with a much bigger number of items to update
         @defer.inlineCallbacks
         def update(progress, count):
-            yield self.store.db.updates._end_background_update("test_update")
+            self.assertEqual(progress, {"my_key": 2})
+            self.assertAlmostEqual(
+                count, target_background_update_duration_ms / duration_ms, places=0,
+            )
+            yield self.updates._end_background_update("test_update")
             return count
 
         self.update_handler.side_effect = update
         self.update_handler.reset_mock()
-        result = yield self.store.db.updates.do_next_background_update(
-            duration_ms * desired_count
+        result = self.get_success(
+            self.updates.do_next_background_update(target_background_update_duration_ms)
         )
         self.assertIsNotNone(result)
-        self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
+        self.update_handler.assert_called_once()
 
         # third step: we don't expect to be called any more
         self.update_handler.reset_mock()
-        result = yield self.store.db.updates.do_next_background_update(
-            duration_ms * desired_count
+        result = self.get_success(
+            self.updates.do_next_background_update(target_background_update_duration_ms)
         )
         self.assertIsNone(result)
         self.assertFalse(self.update_handler.called)
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index eadfb90a22..a331517f4d 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -60,21 +60,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
                 (event_id, bytearray(b"ffff")),
             )
 
-        for i in range(0, 11):
+        for i in range(0, 20):
             yield self.store.db.runInteraction("insert", insert_event, i)
 
-        # this should get the last five and five others
+        # this should get the last ten
         r = yield self.store.get_prev_events_for_room(room_id)
         self.assertEqual(10, len(r))
-        for i in range(0, 5):
-            el = r[i]
-            depth = el[2]
-            self.assertEqual(10 - i, depth)
-
-        for i in range(5, 5):
-            el = r[i]
-            depth = el[2]
-            self.assertLessEqual(5, depth)
+        for i in range(0, 10):
+            self.assertEqual("$event_%i:local" % (19 - i), r[i])
 
     @defer.inlineCallbacks
     def test_get_rooms_with_many_extremities(self):
diff --git a/tests/unittest.py b/tests/unittest.py
index b30b7d1718..ddcd4becfe 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -36,7 +36,7 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.config.ratelimiting import FederationRateLimitConfig
 from synapse.federation.transport import server as federation_server
 from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
+from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.logging.context import LoggingContext
 from synapse.server import HomeServer
 from synapse.types import Requester, UserID, create_requester
@@ -210,6 +210,15 @@ class HomeserverTestCase(TestCase):
         # Register the resources
         self.resource = self.create_test_json_resource()
 
+        # create a site to wrap the resource.
+        self.site = SynapseSite(
+            logger_name="synapse.access.http.fake",
+            site_tag="test",
+            config={},
+            resource=self.resource,
+            server_version_string="1",
+        )
+
         from tests.rest.client.v1.utils import RestHelper
 
         self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
@@ -522,10 +531,6 @@ class HomeserverTestCase(TestCase):
         secrets = self.hs.get_secrets()
         requester = Requester(user, None, False, None, None)
 
-        prev_events_and_hashes = None
-        if prev_event_ids:
-            prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids]
-
         event, context = self.get_success(
             event_creator.create_event(
                 requester,
@@ -535,7 +540,7 @@ class HomeserverTestCase(TestCase):
                     "sender": user.to_string(),
                     "content": {"body": secrets.token_hex(), "msgtype": "m.text"},
                 },
-                prev_events_and_hashes=prev_events_and_hashes,
+                prev_event_ids=prev_event_ids,
             )
         )