diff --git a/tests/__init__.py b/tests/__init__.py
index d3181f9403..f7fc502f01 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -21,4 +21,4 @@ import tests.patch_inline_callbacks
# attempt to do the patch before we load any synapse code
tests.patch_inline_callbacks.do_patch()
-util.DEFAULT_TIMEOUT_DURATION = 10
+util.DEFAULT_TIMEOUT_DURATION = 20
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 3c79d4afe7..18121f4f6c 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,12 +19,14 @@ from mock import Mock
import canonicaljson
import signedjson.key
import signedjson.sign
+from signedjson.key import get_verify_key
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import PerspectivesKeyFetcher, ServerKeyFetcher
+from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
@@ -50,11 +52,11 @@ class MockPerspectiveServer(object):
key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
},
}
- return self.get_signed_response(res)
+ self.sign_response(res)
+ return res
- def get_signed_response(self, res):
+ def sign_response(self, res):
signedjson.sign.sign_json(res, self.server_name, self.key)
- return res
class KeyringTestCase(unittest.HomeserverTestCase):
@@ -80,7 +82,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# we run the lookup in a logcontext so that the patched inlineCallbacks can check
# it is doing the right thing with logcontexts.
wait_1_deferred = run_in_context(
- kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred}
+ kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
)
# there were no previous lookups, so the deferred should be ready
@@ -89,7 +91,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# set off another wait. It should block because the first lookup
# hasn't yet completed.
wait_2_deferred = run_in_context(
- kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred}
+ kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
)
self.assertFalse(wait_2_deferred.called)
@@ -132,7 +134,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
context_11.request = "11"
res_deferreds = kr.verify_json_objects_for_server(
- [("server10", json1), ("server11", {})]
+ [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
)
# the unsigned json should be rejected pretty quickly
@@ -169,7 +171,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server(
- [("server10", json1)]
+ [("server10", json1, 0, "test")]
)
res_deferreds_2[0].addBoth(self.check_context, None)
yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
@@ -192,31 +194,125 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- r = self.hs.datastore.store_server_verify_key(
- "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
+ r = self.hs.datastore.store_server_verify_keys(
+ "server9",
+ time.time() * 1000,
+ [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
)
self.get_success(r)
+
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
# should fail immediately on an unsigned object
- d = _verify_json_for_server(kr, "server9", {})
+ d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
self.failureResultOf(d, SynapseError)
- d = _verify_json_for_server(kr, "server9", json1)
- self.assertFalse(d.called)
+ # should suceed on a signed object
+ d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
+ # self.assertFalse(d.called)
self.get_success(d)
+ def test_verify_json_dedupes_key_requests(self):
+ """Two requests for the same key should be deduped."""
+ key1 = signedjson.key.generate_signing_key(1)
+
+ def get_keys(keys_to_fetch):
+ # there should only be one request object (with the max validity)
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+ }
+ }
+ )
+
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock(side_effect=get_keys)
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
+
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server1", key1)
+
+ # the first request should succeed; the second should fail because the key
+ # has expired
+ results = kr.verify_json_objects_for_server(
+ [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
+ )
+ self.assertEqual(len(results), 2)
+ self.get_success(results[0])
+ e = self.get_failure(results[1], SynapseError).value
+ self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+ self.assertEqual(e.code, 401)
+
+ # there should have been a single call to the fetcher
+ mock_fetcher.get_keys.assert_called_once()
+
+ def test_verify_json_falls_back_to_other_fetchers(self):
+ """If the first fetcher cannot provide a recent enough key, we fall back"""
+ key1 = signedjson.key.generate_signing_key(1)
+
+ def get_keys1(keys_to_fetch):
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
+ }
+ }
+ )
+
+ def get_keys2(keys_to_fetch):
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+ }
+ }
+ )
+
+ mock_fetcher1 = keyring.KeyFetcher()
+ mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
+ mock_fetcher2 = keyring.KeyFetcher()
+ mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
+
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server1", key1)
+
+ results = kr.verify_json_objects_for_server(
+ [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
+ )
+ self.assertEqual(len(results), 2)
+ self.get_success(results[0])
+ e = self.get_failure(results[1], SynapseError).value
+ self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+ self.assertEqual(e.code, 401)
+
+ # there should have been a single call to each fetcher
+ mock_fetcher1.get_keys.assert_called_once()
+ mock_fetcher2.get_keys.assert_called_once()
+
+
+class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.http_client = Mock()
+ hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+ return hs
+
def test_get_keys_from_server(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
SERVER_NAME = "server2"
- kr = keyring.Keyring(self.hs)
+ fetcher = ServerKeyFetcher(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
- VALID_UNTIL_TS = 1000
+ VALID_UNTIL_TS = 200 * 1000
# valid response
response = {
@@ -238,12 +334,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json
- server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
k = keys[SERVER_NAME][testverifykey_id]
- self.assertEqual(k, testverifykey)
- self.assertEqual(k.alg, "ed25519")
- self.assertEqual(k.version, "ver1")
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -263,18 +360,29 @@ class KeyringTestCase(unittest.HomeserverTestCase):
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
- # change the server name: it should cause a rejection
+ # change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
- self.get_failure(
- kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
- )
+
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+ self.assertEqual(keys, {})
+
+
+class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ self.mock_perspective_server = MockPerspectiveServer()
+ self.http_client = Mock()
+ hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+ keys = self.mock_perspective_server.get_verify_keys()
+ hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
+ return hs
def test_get_keys_from_perspectives(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
+ fetcher = PerspectivesKeyFetcher(self.hs)
+
SERVER_NAME = "server2"
- kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
@@ -292,9 +400,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
},
}
- persp_resp = {
- "server_keys": [self.mock_perspective_server.get_signed_response(response)]
- }
+ # the response must be signed by both the origin server and the perspectives
+ # server.
+ signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+ self.mock_perspective_server.sign_response(response)
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -303,17 +412,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# check that the request is for the expected key
q = data["server_keys"]
self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
- return persp_resp
+ return {"server_keys": [response]}
self.http_client.post_json.side_effect = post_json
- server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
- self.assertEqual(k, testverifykey)
- self.assertEqual(k.alg, "ed25519")
- self.assertEqual(k.version, "ver1")
+ self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+ self.assertEqual(k.verify_key, testverifykey)
+ self.assertEqual(k.verify_key.alg, "ed25519")
+ self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -330,25 +440,96 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertEqual(
bytes(res["key_json"]),
- canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]),
+ canonicaljson.encode_canonical_json(response),
)
+ def test_invalid_perspectives_responses(self):
+ """Check that invalid responses from the perspectives server are rejected"""
+ # arbitrarily advance the clock a bit
+ self.reactor.advance(100)
+
+ SERVER_NAME = "server2"
+ testkey = signedjson.key.generate_signing_key("ver1")
+ testverifykey = signedjson.key.get_verify_key(testkey)
+ testverifykey_id = "ed25519:ver1"
+ VALID_UNTIL_TS = 200 * 1000
+
+ def build_response():
+ # valid response
+ response = {
+ "server_name": SERVER_NAME,
+ "old_verify_keys": {},
+ "valid_until_ts": VALID_UNTIL_TS,
+ "verify_keys": {
+ testverifykey_id: {
+ "key": signedjson.key.encode_verify_key_base64(testverifykey)
+ }
+ },
+ }
+
+ # the response must be signed by both the origin server and the perspectives
+ # server.
+ signedjson.sign.sign_json(response, SERVER_NAME, testkey)
+ self.mock_perspective_server.sign_response(response)
+ return response
+
+ def get_key_from_perspectives(response):
+ fetcher = PerspectivesKeyFetcher(self.hs)
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+
+ def post_json(destination, path, data, **kwargs):
+ self.assertEqual(destination, self.mock_perspective_server.server_name)
+ self.assertEqual(path, "/_matrix/key/v2/query")
+ return {"server_keys": [response]}
+
+ self.http_client.post_json.side_effect = post_json
+
+ return self.get_success(fetcher.get_keys(keys_to_fetch))
+
+ # start with a valid response so we can check we are testing the right thing
+ response = build_response()
+ keys = get_key_from_perspectives(response)
+ k = keys[SERVER_NAME][testverifykey_id]
+ self.assertEqual(k.verify_key, testverifykey)
+
+ # remove the perspectives server's signature
+ response = build_response()
+ del response["signatures"][self.mock_perspective_server.server_name]
+ self.http_client.post_json.return_value = {"server_keys": [response]}
+ keys = get_key_from_perspectives(response)
+ self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
+
+ # remove the origin server's signature
+ response = build_response()
+ del response["signatures"][SERVER_NAME]
+ self.http_client.post_json.return_value = {"server_keys": [response]}
+ keys = get_key_from_perspectives(response)
+ self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
+
+
+def get_key_id(key):
+ """Get the matrix ID tag for a given SigningKey or VerifyKey"""
+ return "%s:%s" % (key.alg, key.version)
+
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
- with LoggingContext("testctx"):
+ with LoggingContext("testctx") as ctx:
+ # we set the "request" prop to make it easier to follow what's going on in the
+ # logs.
+ ctx.request = "testctx"
rv = yield f(*args, **kwargs)
defer.returnValue(rv)
-def _verify_json_for_server(keyring, server_name, json_object):
+def _verify_json_for_server(kr, *args):
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
with the patched defer.inlineCallbacks.
"""
@defer.inlineCallbacks
def v():
- rv1 = yield keyring.verify_json_for_server(server_name, json_object)
+ rv1 = yield kr.verify_json_for_server(*args)
defer.returnValue(rv1)
return run_in_context(v)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
new file mode 100644
index 0000000000..1e3e5aec66
--- /dev/null
+++ b/tests/federation/test_complexity.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 Matrix.org Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.federation.transport import server
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests import unittest
+
+
+class RoomComplexityTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self, name='test'):
+ config = super(RoomComplexityTests, self).default_config(name=name)
+ config["limit_large_remote_room_joins"] = True
+ config["limit_large_remote_room_complexity"] = 0.05
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ class Authenticator(object):
+ def authenticate_request(self, request, content):
+ return defer.succeed("otherserver.nottld")
+
+ ratelimiter = FederationRateLimiter(
+ clock,
+ FederationRateLimitConfig(
+ window_size=1,
+ sleep_limit=1,
+ sleep_msec=1,
+ reject_limit=1000,
+ concurrent_requests=1000,
+ ),
+ )
+ server.register_servlets(
+ homeserver, self.resource, Authenticator(), ratelimiter
+ )
+
+ def test_complexity_simple(self):
+
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.send_state(
+ room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
+ )
+
+ # Get the room complexity
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ complexity = channel.json_body["v1"]
+ self.assertTrue(complexity > 0, complexity)
+
+ # Artificially raise the complexity
+ store = self.hs.get_datastore()
+ store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23)
+
+ # Get the room complexity again -- make sure it's our artificial value
+ request, channel = self.make_request(
+ "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+ complexity = channel.json_body["v1"]
+ self.assertEqual(complexity, 1.23)
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
new file mode 100644
index 0000000000..2710c991cf
--- /dev/null
+++ b/tests/handlers/test_stats.py
@@ -0,0 +1,307 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests import unittest
+
+
+class StatsRoomTests(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+
+ self.store = hs.get_datastore()
+ self.handler = self.hs.get_stats_handler()
+
+ def _add_background_updates(self):
+ """
+ Add the background updates we need to run.
+ """
+ # Ugh, have to reset this flag
+ self.store._all_done = False
+
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+ )
+ )
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_process_rooms",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_createtables",
+ },
+ )
+ )
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_rooms",
+ },
+ )
+ )
+
+ def test_initial_room(self):
+ """
+ The background updates will build the table from scratch.
+ """
+ r = self.get_success(self.store.get_all_room_state())
+ self.assertEqual(len(r), 0)
+
+ # Disable stats
+ self.hs.config.stats_enabled = False
+ self.handler.stats_enabled = False
+
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.send_state(
+ room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
+ )
+
+ # Stats disabled, shouldn't have done anything
+ r = self.get_success(self.store.get_all_room_state())
+ self.assertEqual(len(r), 0)
+
+ # Enable stats
+ self.hs.config.stats_enabled = True
+ self.handler.stats_enabled = True
+
+ # Do the initial population of the user directory via the background update
+ self._add_background_updates()
+
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ r = self.get_success(self.store.get_all_room_state())
+
+ self.assertEqual(len(r), 1)
+ self.assertEqual(r[0]["topic"], "foo")
+
+ def test_initial_earliest_token(self):
+ """
+ Ingestion via notify_new_event will ignore tokens that the background
+ update have already processed.
+ """
+ self.reactor.advance(86401)
+
+ self.hs.config.stats_enabled = False
+ self.handler.stats_enabled = False
+
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ u2 = self.register_user("u2", "pass")
+ u2_token = self.login("u2", "pass")
+
+ u3 = self.register_user("u3", "pass")
+ u3_token = self.login("u3", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+ self.helper.send_state(
+ room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
+ )
+
+ # Begin the ingestion by creating the temp tables. This will also store
+ # the position that the deltas should begin at, once they take over.
+ self.hs.config.stats_enabled = True
+ self.handler.stats_enabled = True
+ self.store._all_done = False
+ self.get_success(self.store.update_stats_stream_pos(None))
+
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {"update_name": "populate_stats_createtables", "progress_json": "{}"},
+ )
+ )
+
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ # Now, before the table is actually ingested, add some more events.
+ self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room=room_1, user=u2, tok=u2_token)
+
+ # Now do the initial ingestion.
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
+ )
+ )
+ self.get_success(
+ self.store._simple_insert(
+ "background_updates",
+ {
+ "update_name": "populate_stats_cleanup",
+ "progress_json": "{}",
+ "depends_on": "populate_stats_process_rooms",
+ },
+ )
+ )
+
+ self.store._all_done = False
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ self.reactor.advance(86401)
+
+ # Now add some more events, triggering ingestion. Because of the stream
+ # position being set to before the events sent in the middle, a simpler
+ # implementation would reprocess those events, and say there were four
+ # users, not three.
+ self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token)
+ self.helper.join(room=room_1, user=u3, tok=u3_token)
+
+ # Get the deltas! There should be two -- day 1, and day 2.
+ r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
+
+ # The oldest has 2 joined members
+ self.assertEqual(r[-1]["joined_members"], 2)
+
+ # The newest has 3
+ self.assertEqual(r[0]["joined_members"], 3)
+
+ def test_incorrect_state_transition(self):
+ """
+ If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to
+ (JOIN, INVITE, LEAVE, BAN), an error is raised.
+ """
+ events = {
+ "a1": {"membership": Membership.LEAVE},
+ "a2": {"membership": "not a real thing"},
+ }
+
+ def get_event(event_id, allow_none=True):
+ m = Mock()
+ m.content = events[event_id]
+ d = defer.Deferred()
+ self.reactor.callLater(0.0, d.callback, m)
+ return d
+
+ def get_received_ts(event_id):
+ return defer.succeed(1)
+
+ self.store.get_received_ts = get_received_ts
+ self.store.get_event = get_event
+
+ deltas = [
+ {
+ "type": EventTypes.Member,
+ "state_key": "some_user",
+ "room_id": "room",
+ "event_id": "a1",
+ "prev_event_id": "a2",
+ "stream_id": 60,
+ }
+ ]
+
+ f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+ self.assertEqual(
+ f.value.args[0], "'not a real thing' is not a valid prev_membership"
+ )
+
+ # And the other way...
+ deltas = [
+ {
+ "type": EventTypes.Member,
+ "state_key": "some_user",
+ "room_id": "room",
+ "event_id": "a2",
+ "prev_event_id": "a1",
+ "stream_id": 100,
+ }
+ ]
+
+ f = self.get_failure(self.handler._handle_deltas(deltas), ValueError)
+ self.assertEqual(
+ f.value.args[0], "'not a real thing' is not a valid membership"
+ )
+
+ def test_redacted_prev_event(self):
+ """
+ If the prev_event does not exist, then it is assumed to be a LEAVE.
+ """
+ u1 = self.register_user("u1", "pass")
+ u1_token = self.login("u1", "pass")
+
+ room_1 = self.helper.create_room_as(u1, tok=u1_token)
+
+ # Do the initial population of the user directory via the background update
+ self._add_background_updates()
+
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ events = {
+ "a1": None,
+ "a2": {"membership": Membership.JOIN},
+ }
+
+ def get_event(event_id, allow_none=True):
+ if events.get(event_id):
+ m = Mock()
+ m.content = events[event_id]
+ else:
+ m = None
+ d = defer.Deferred()
+ self.reactor.callLater(0.0, d.callback, m)
+ return d
+
+ def get_received_ts(event_id):
+ return defer.succeed(1)
+
+ self.store.get_received_ts = get_received_ts
+ self.store.get_event = get_event
+
+ deltas = [
+ {
+ "type": EventTypes.Member,
+ "state_key": "some_user:test",
+ "room_id": room_1,
+ "event_id": "a2",
+ "prev_event_id": "a1",
+ "stream_id": 100,
+ }
+ ]
+
+ # Handle our fake deltas, which has a user going from LEAVE -> JOIN.
+ self.get_success(self.handler._handle_deltas(deltas))
+
+ # One delta, with two joined members -- the room creator, and our fake
+ # user.
+ r = self.get_success(self.store.get_deltas_for_room(room_1, 0))
+ self.assertEqual(len(r), 1)
+ self.assertEqual(r[0]["joined_members"], 2)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index ed0ca079d9..4153da4da7 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -27,6 +27,7 @@ from twisted.web.http import HTTPChannel
from twisted.web.http_headers import Headers
from twisted.web.iweb import IPolicyForHTTPS
+from synapse.config.homeserver import HomeServerConfig
from synapse.crypto.context_factory import ClientTLSOptionsFactory
from synapse.http.federation.matrix_federation_agent import (
MatrixFederationAgent,
@@ -52,11 +53,16 @@ class MatrixFederationAgentTests(TestCase):
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+ # for now, we disable cert verification for the test, since the cert we
+ # present will not be trusted. We should do better here, though.
+ config_dict = default_config("test", parse=False)
+ config_dict["federation_verify_certificates"] = False
+ config = HomeServerConfig()
+ config.parse_config_dict(config_dict)
+
self.agent = MatrixFederationAgent(
reactor=self.reactor,
- tls_client_options_factory=ClientTLSOptionsFactory(
- default_config("test", parse=True)
- ),
+ tls_client_options_factory=ClientTLSOptionsFactory(config),
_well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
_srv_resolver=self.mock_resolver,
_well_known_cache=self.well_known_cache,
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index ee5f09041f..e5fc2fcd15 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -408,7 +408,6 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertEqual([], users_in_room)
- @unittest.DEBUG
def test_shutdown_room_block_peek(self):
"""Test that a world_readable room can no longer be peeked into after
it has been shut down.
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 769c37ce52..72c7ed93cb 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -14,6 +14,8 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
+import json
+
from mock import Mock
from twisted.internet import defer
@@ -28,11 +30,14 @@ from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test"
-PATH_PREFIX = "/_matrix/client/api/v1"
+PATH_PREFIX = "/_matrix/client/r0"
+
+class MockHandlerProfileTestCase(unittest.TestCase):
+ """ Tests rest layer of profile management.
-class ProfileTestCase(unittest.TestCase):
- """ Tests profile management. """
+ Todo: move these into ProfileTestCase
+ """
@defer.inlineCallbacks
def setUp(self):
@@ -159,6 +164,59 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")
+class ProfileTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ profile.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.owner = self.register_user("owner", "pass")
+ self.owner_tok = self.login("owner", "pass")
+
+ def test_set_displayname(self):
+ request, channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.owner, ),
+ content=json.dumps({"displayname": "test"}),
+ access_token=self.owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ res = self.get_displayname()
+ self.assertEqual(res, "test")
+
+ def test_set_displayname_too_long(self):
+ """Attempts to set a stupid displayname should get a 400"""
+ request, channel = self.make_request(
+ "PUT",
+ "/profile/%s/displayname" % (self.owner, ),
+ content=json.dumps({"displayname": "test" * 100}),
+ access_token=self.owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 400, channel.result)
+
+ res = self.get_displayname()
+ self.assertEqual(res, "owner")
+
+ def get_displayname(self):
+ request, channel = self.make_request(
+ "GET",
+ "/profile/%s/displayname" % (self.owner, ),
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["displayname"]
+
+
class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 05b0143c42..f7133fc12e 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -127,3 +127,20 @@ class RestHelper(object):
)
return channel.json_body
+
+ def send_state(self, room_id, event_type, body, tok, expect_code=200):
+ path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type)
+ if tok:
+ path = path + "?access_token=%s" % tok
+
+ request, channel = make_request(
+ self.hs.get_reactor(), "PUT", path, json.dumps(body).encode('utf8')
+ )
+ render(request, self.resource, self.hs.get_reactor())
+
+ assert int(channel.result["code"]) == expect_code, (
+ "Expected: %d, got: %d, resp: %r"
+ % (expect_code, int(channel.result["code"]), channel.result["body"])
+ )
+
+ return channel.json_body
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index f3ef977404..bce5b0cf4c 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.rest.admin
-from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import capabilities
@@ -32,6 +32,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.url = b"/_matrix/client/r0/capabilities"
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
+ self.config = hs.config
return hs
def test_check_auth_required(self):
@@ -51,8 +52,10 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
for room_version in capabilities['m.room_versions']['available'].keys():
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
+
self.assertEqual(
- DEFAULT_ROOM_VERSION.identifier, capabilities['m.room_versions']['default']
+ self.config.default_room_version.identifier,
+ capabilities['m.room_versions']['default'],
)
def test_get_change_password_capabilities(self):
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 65685883db..0cb6a363d6 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,3 +1,20 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import datetime
import json
import os
@@ -409,3 +426,46 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEqual(len(self.email_attempts), 1)
+
+
+class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.validity_period = 10
+ self.max_delta = self.validity_period * 10. / 100.
+
+ config = self.default_config()
+
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": False,
+ }
+
+ self.hs = self.setup_test_homeserver(config=config)
+ self.hs.config.account_validity.period = self.validity_period
+
+ self.store = self.hs.get_datastore()
+
+ return self.hs
+
+ def test_background_job(self):
+ """
+ Tests the same thing as test_background_job, except that it sets the
+ startup_job_max_delta parameter and checks that the expiration date is within the
+ allowed range.
+ """
+ user_id = self.register_user("kermit_delta", "user")
+
+ self.hs.config.account_validity.startup_job_max_delta = self.max_delta
+
+ now_ms = self.hs.clock.time_msec()
+ self.get_success(self.store._set_expiration_date_when_missing())
+
+ res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+
+ self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
+ self.assertLessEqual(res, now_ms + self.validity_period)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 3d040cf118..43b3049daa 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -90,6 +90,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
self.assertEquals(400, channel.code, channel.json_body)
+ def test_deny_double_react(self):
+ """Test that we deny relations on membership events
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self.assertEquals(400, channel.code, channel.json_body)
+
def test_basic_paginate_relations(self):
"""Tests that calling pagination API corectly the latest relations.
"""
@@ -234,14 +243,30 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"""Test that we can paginate within an annotation group.
"""
+ # We need to create ten separate users to send each reaction.
+ access_tokens = [self.user_token, self.user2_token]
+ idx = 0
+ while len(access_tokens) < 10:
+ user_id, token = self._create_user("test" + str(idx))
+ idx += 1
+
+ self.helper.join(self.room, user=user_id, tok=token)
+ access_tokens.append(token)
+
+ idx = 0
expected_event_ids = []
for _ in range(10):
channel = self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", key=u"👍"
+ RelationTypes.ANNOTATION,
+ "m.reaction",
+ key=u"👍",
+ access_token=access_tokens[idx],
)
self.assertEquals(200, channel.code, channel.json_body)
expected_event_ids.append(channel.json_body["event_id"])
+ idx += 1
+
# Also send a different type of reaction so that we test we don't see it
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
new file mode 100644
index 0000000000..6dda66ecd3
--- /dev/null
+++ b/tests/storage/test_cleanup_extrems.py
@@ -0,0 +1,248 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path
+
+from synapse.api.constants import EventTypes
+from synapse.storage import prepare_database
+from synapse.types import Requester, UserID
+
+from tests.unittest import HomeserverTestCase
+
+
+class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
+ """Test the background update to clean forward extremities table.
+ """
+
+ def prepare(self, reactor, clock, homeserver):
+ self.store = homeserver.get_datastore()
+ self.event_creator = homeserver.get_event_creation_handler()
+ self.room_creator = homeserver.get_room_creation_handler()
+
+ # Create a test user and room
+ self.user = UserID("alice", "test")
+ self.requester = Requester(self.user, None, False, None, None)
+ info = self.get_success(self.room_creator.create_room(self.requester, {}))
+ self.room_id = info["room_id"]
+
+ def create_and_send_event(self, soft_failed=False, prev_event_ids=None):
+ """Create and send an event.
+
+ Args:
+ soft_failed (bool): Whether to create a soft failed event or not
+ prev_event_ids (list[str]|None): Explicitly set the prev events,
+ or if None just use the default
+
+ Returns:
+ str: The new event's ID.
+ """
+ prev_events_and_hashes = None
+ if prev_event_ids:
+ prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids]
+
+ event, context = self.get_success(
+ self.event_creator.create_event(
+ self.requester,
+ {
+ "type": EventTypes.Message,
+ "room_id": self.room_id,
+ "sender": self.user.to_string(),
+ "content": {"body": "", "msgtype": "m.text"},
+ },
+ prev_events_and_hashes=prev_events_and_hashes,
+ )
+ )
+
+ if soft_failed:
+ event.internal_metadata.soft_failed = True
+
+ self.get_success(
+ self.event_creator.send_nonmember_event(self.requester, event, context)
+ )
+
+ return event.event_id
+
+ def add_extremity(self, event_id):
+ """Add the given event as an extremity to the room.
+ """
+ self.get_success(
+ self.store._simple_insert(
+ table="event_forward_extremities",
+ values={"room_id": self.room_id, "event_id": event_id},
+ desc="test_add_extremity",
+ )
+ )
+
+ self.store.get_latest_event_ids_in_room.invalidate((self.room_id,))
+
+ def run_background_update(self):
+ """Re run the background update to clean up the extremities.
+ """
+ # Make sure we don't clash with in progress updates.
+ self.assertTrue(self.store._all_done, "Background updates are still ongoing")
+
+ schema_path = os.path.join(
+ prepare_database.dir_path,
+ "schema",
+ "delta",
+ "54",
+ "delete_forward_extremities.sql",
+ )
+
+ def run_delta_file(txn):
+ prepare_database.executescript(txn, schema_path)
+
+ self.get_success(
+ self.store.runInteraction("test_delete_forward_extremities", run_delta_file)
+ )
+
+ # Ugh, have to reset this flag
+ self.store._all_done = False
+
+ while not self.get_success(self.store.has_completed_background_updates()):
+ self.get_success(self.store.do_next_background_update(100), by=0.1)
+
+ def test_soft_failed_extremities_handled_correctly(self):
+ """Test that extremities are correctly calculated in the presence of
+ soft failed events.
+
+ Tests a graph like:
+
+ A <- SF1 <- SF2 <- B
+
+ Where SF* are soft failed.
+ """
+
+ # Create the room graph
+ event_id_1 = self.create_and_send_event()
+ event_id_2 = self.create_and_send_event(True, [event_id_1])
+ event_id_3 = self.create_and_send_event(True, [event_id_2])
+ event_id_4 = self.create_and_send_event(False, [event_id_3])
+
+ # Check the latest events are as expected
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+
+ self.assertEqual(latest_event_ids, [event_id_4])
+
+ def test_basic_cleanup(self):
+ """Test that extremities are correctly calculated in the presence of
+ soft failed events.
+
+ Tests a graph like:
+
+ A <- SF1 <- B
+
+ Where SF* are soft failed, and with extremities of A and B
+ """
+ # Create the room graph
+ event_id_a = self.create_and_send_event()
+ event_id_sf1 = self.create_and_send_event(True, [event_id_a])
+ event_id_b = self.create_and_send_event(False, [event_id_sf1])
+
+ # Add the new extremity and check the latest events are as expected
+ self.add_extremity(event_id_a)
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+
+ # Run the background update and check it did the right thing
+ self.run_background_update()
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(latest_event_ids, [event_id_b])
+
+ def test_chain_of_fail_cleanup(self):
+ """Test that extremities are correctly calculated in the presence of
+ soft failed events.
+
+ Tests a graph like:
+
+ A <- SF1 <- SF2 <- B
+
+ Where SF* are soft failed, and with extremities of A and B
+ """
+ # Create the room graph
+ event_id_a = self.create_and_send_event()
+ event_id_sf1 = self.create_and_send_event(True, [event_id_a])
+ event_id_sf2 = self.create_and_send_event(True, [event_id_sf1])
+ event_id_b = self.create_and_send_event(False, [event_id_sf2])
+
+ # Add the new extremity and check the latest events are as expected
+ self.add_extremity(event_id_a)
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+
+ # Run the background update and check it did the right thing
+ self.run_background_update()
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(latest_event_ids, [event_id_b])
+
+ def test_forked_graph_cleanup(self):
+ r"""Test that extremities are correctly calculated in the presence of
+ soft failed events.
+
+ Tests a graph like, where time flows down the page:
+
+ A B
+ / \ /
+ / \ /
+ SF1 SF2
+ | |
+ SF3 |
+ / \ |
+ | \ |
+ C SF4
+
+ Where SF* are soft failed, and with them A, B and C marked as
+ extremities. This should resolve to B and C being marked as extremity.
+ """
+ # Create the room graph
+ event_id_a = self.create_and_send_event()
+ event_id_b = self.create_and_send_event()
+ event_id_sf1 = self.create_and_send_event(True, [event_id_a])
+ event_id_sf2 = self.create_and_send_event(True, [event_id_a, event_id_b])
+ event_id_sf3 = self.create_and_send_event(True, [event_id_sf1])
+ self.create_and_send_event(True, [event_id_sf2, event_id_sf3]) # SF4
+ event_id_c = self.create_and_send_event(False, [event_id_sf3])
+
+ # Add the new extremity and check the latest events are as expected
+ self.add_extremity(event_id_a)
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(
+ set(latest_event_ids), set((event_id_a, event_id_b, event_id_c))
+ )
+
+ # Run the background update and check it did the right thing
+ self.run_background_update()
+
+ latest_event_ids = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
+ )
+ self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c]))
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 6bfaa00fe9..e07ff01201 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -17,6 +17,8 @@ import signedjson.key
from twisted.internet.defer import Deferred
+from synapse.storage.keys import FetchKeyResult
+
import tests.unittest
KEY_1 = signedjson.key.decode_verify_key_base64(
@@ -31,23 +33,34 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_server_verify_keys(self):
store = self.hs.get_datastore()
- d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
- self.get_success(d)
- d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
+ key_id_1 = "ed25519:key1"
+ key_id_2 = "ed25519:KEY_ID_2"
+ d = store.store_server_verify_keys(
+ "from_server",
+ 10,
+ [
+ ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
self.get_success(d)
d = store.get_server_verify_keys(
- [
- ("server1", "ed25519:key1"),
- ("server1", "ed25519:key2"),
- ("server1", "ed25519:key3"),
- ]
+ [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
)
res = self.get_success(d)
self.assertEqual(len(res.keys()), 3)
- self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
- self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
+ res1 = res[("server1", key_id_1)]
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.verify_key.version, "key1")
+ self.assertEqual(res1.valid_until_ts, 100)
+
+ res2 = res[("server1", key_id_2)]
+ self.assertEqual(res2.verify_key, KEY_2)
+ # version comes from the ID it was stored with
+ self.assertEqual(res2.verify_key.version, "KEY_ID_2")
+ self.assertEqual(res2.valid_until_ts, 200)
# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
@@ -60,32 +73,51 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"
- d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
- self.get_success(d)
- d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
+ d = store.store_server_verify_keys(
+ "from_server",
+ 0,
+ [
+ ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
- self.assertEqual(res[("srv1", key_id_2)], KEY_2)
+
+ res1 = res[("srv1", key_id_1)]
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.valid_until_ts, 100)
+
+ res2 = res[("srv1", key_id_2)]
+ self.assertEqual(res2.verify_key, KEY_2)
+ self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit
res = store.get_server_verify_keys([("srv1", key_id_1)])
if isinstance(res, Deferred):
res = self.successResultOf(res)
self.assertEqual(len(res.keys()), 1)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+ self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
new_key_2 = signedjson.key.get_verify_key(
signedjson.key.generate_signing_key("key2")
)
- d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
+ d = store.store_server_verify_keys(
+ "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
+ )
self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
- self.assertEqual(res[("srv1", key_id_2)], new_key_2)
+
+ res1 = res[("srv1", key_id_1)]
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.valid_until_ts, 100)
+
+ res2 = res[("srv1", key_id_2)]
+ self.assertEqual(res2.verify_key, new_key_2)
+ self.assertEqual(res2.valid_until_ts, 300)
|