summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py2
-rw-r--r--tests/crypto/test_keyring.py225
-rw-r--r--tests/federation/test_complexity.py90
-rw-r--r--tests/rest/admin/test_admin.py1
-rw-r--r--tests/rest/client/v1/test_profile.py64
-rw-r--r--tests/rest/client/v2_alpha/test_register.py15
-rw-r--r--tests/storage/test_cleanup_extrems.py248
7 files changed, 595 insertions, 50 deletions
diff --git a/tests/__init__.py b/tests/__init__.py

index d3181f9403..f7fc502f01 100644 --- a/tests/__init__.py +++ b/tests/__init__.py
@@ -21,4 +21,4 @@ import tests.patch_inline_callbacks # attempt to do the patch before we load any synapse code tests.patch_inline_callbacks.do_patch() -util.DEFAULT_TIMEOUT_DURATION = 10 +util.DEFAULT_TIMEOUT_DURATION = 20 diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index de61bad15d..4cff7e36c8 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py
@@ -19,16 +19,13 @@ from mock import Mock import canonicaljson import signedjson.key import signedjson.sign +from signedjson.key import get_verify_key from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.crypto import keyring -from synapse.crypto.keyring import ( - KeyLookupError, - PerspectivesKeyFetcher, - ServerKeyFetcher, -) +from synapse.crypto.keyring import PerspectivesKeyFetcher, ServerKeyFetcher from synapse.storage.keys import FetchKeyResult from synapse.util import logcontext from synapse.util.logcontext import LoggingContext @@ -55,11 +52,11 @@ class MockPerspectiveServer(object): key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)} }, } - return self.get_signed_response(res) + self.sign_response(res) + return res - def get_signed_response(self, res): + def sign_response(self, res): signedjson.sign.sign_json(res, self.server_name, self.key) - return res class KeyringTestCase(unittest.HomeserverTestCase): @@ -85,7 +82,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # we run the lookup in a logcontext so that the patched inlineCallbacks can check # it is doing the right thing with logcontexts. wait_1_deferred = run_in_context( - kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred} + kr.wait_for_previous_lookups, {"server1": lookup_1_deferred} ) # there were no previous lookups, so the deferred should be ready @@ -94,7 +91,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # set off another wait. It should block because the first lookup # hasn't yet completed. wait_2_deferred = run_in_context( - kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred} + kr.wait_for_previous_lookups, {"server1": lookup_2_deferred} ) self.assertFalse(wait_2_deferred.called) @@ -137,7 +134,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): context_11.request = "11" res_deferreds = kr.verify_json_objects_for_server( - [("server10", json1), ("server11", {})] + [("server10", json1, 0), ("server11", {}, 0)] ) # the unsigned json should be rejected pretty quickly @@ -174,7 +171,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.http_client.post_json.return_value = defer.Deferred() res_deferreds_2 = kr.verify_json_objects_for_server( - [("server10", json1)] + [("server10", json1, 0)] ) res_deferreds_2[0].addBoth(self.check_context, None) yield logcontext.make_deferred_yieldable(res_deferreds_2[0]) @@ -197,31 +194,108 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - key1_id = "%s:%s" % (key1.alg, key1.version) - r = self.hs.datastore.store_server_verify_keys( "server9", time.time() * 1000, - [ - ( - "server9", - key1_id, - FetchKeyResult(signedjson.key.get_verify_key(key1), 1000), - ), - ], + [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], ) self.get_success(r) + json1 = {} signedjson.sign.sign_json(json1, "server9", key1) # should fail immediately on an unsigned object - d = _verify_json_for_server(kr, "server9", {}) + d = _verify_json_for_server(kr, "server9", {}, 0) self.failureResultOf(d, SynapseError) - d = _verify_json_for_server(kr, "server9", json1) - self.assertFalse(d.called) + # should suceed on a signed object + d = _verify_json_for_server(kr, "server9", json1, 500) + # self.assertFalse(d.called) self.get_success(d) + def test_verify_json_dedupes_key_requests(self): + """Two requests for the same key should be deduped.""" + key1 = signedjson.key.generate_signing_key(1) + + def get_keys(keys_to_fetch): + # there should only be one request object (with the max validity) + self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) + + return defer.succeed( + { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) + } + } + ) + + mock_fetcher = keyring.KeyFetcher() + mock_fetcher.get_keys = Mock(side_effect=get_keys) + kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) + + json1 = {} + signedjson.sign.sign_json(json1, "server1", key1) + + # the first request should succeed; the second should fail because the key + # has expired + results = kr.verify_json_objects_for_server( + [("server1", json1, 500), ("server1", json1, 1500)] + ) + self.assertEqual(len(results), 2) + self.get_success(results[0]) + e = self.get_failure(results[1], SynapseError).value + self.assertEqual(e.errcode, "M_UNAUTHORIZED") + self.assertEqual(e.code, 401) + + # there should have been a single call to the fetcher + mock_fetcher.get_keys.assert_called_once() + + def test_verify_json_falls_back_to_other_fetchers(self): + """If the first fetcher cannot provide a recent enough key, we fall back""" + key1 = signedjson.key.generate_signing_key(1) + + def get_keys1(keys_to_fetch): + self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) + return defer.succeed( + { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800) + } + } + ) + + def get_keys2(keys_to_fetch): + self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) + return defer.succeed( + { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) + } + } + ) + + mock_fetcher1 = keyring.KeyFetcher() + mock_fetcher1.get_keys = Mock(side_effect=get_keys1) + mock_fetcher2 = keyring.KeyFetcher() + mock_fetcher2.get_keys = Mock(side_effect=get_keys2) + kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2)) + + json1 = {} + signedjson.sign.sign_json(json1, "server1", key1) + + results = kr.verify_json_objects_for_server( + [("server1", json1, 1200), ("server1", json1, 1500)] + ) + self.assertEqual(len(results), 2) + self.get_success(results[0]) + e = self.get_failure(results[1], SynapseError).value + self.assertEqual(e.errcode, "M_UNAUTHORIZED") + self.assertEqual(e.code, 401) + + # there should have been a single call to each fetcher + mock_fetcher1.get_keys.assert_called_once() + mock_fetcher2.get_keys.assert_called_once() + class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): @@ -238,7 +312,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): testkey = signedjson.key.generate_signing_key("ver1") testverifykey = signedjson.key.get_verify_key(testkey) testverifykey_id = "ed25519:ver1" - VALID_UNTIL_TS = 1000 + VALID_UNTIL_TS = 200 * 1000 # valid response response = { @@ -260,8 +334,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): self.http_client.get_json.side_effect = get_json - server_name_and_key_ids = [(SERVER_NAME, ("key1",))] - keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) k = keys[SERVER_NAME][testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.verify_key, testverifykey) @@ -286,11 +360,11 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) ) - # change the server name: it should cause a rejection + # change the server name: the result should be ignored response["server_name"] = "OTHER_SERVER" - self.get_failure( - fetcher.get_keys(server_name_and_key_ids), KeyLookupError - ) + + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) + self.assertEqual(keys, {}) class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): @@ -326,9 +400,10 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): }, } - persp_resp = { - "server_keys": [self.mock_perspective_server.get_signed_response(response)] - } + # the response must be signed by both the origin server and the perspectives + # server. + signedjson.sign.sign_json(response, SERVER_NAME, testkey) + self.mock_perspective_server.sign_response(response) def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) @@ -337,12 +412,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # check that the request is for the expected key q = data["server_keys"] self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"]) - return persp_resp + return {"server_keys": [response]} self.http_client.post_json.side_effect = post_json - server_name_and_key_ids = [(SERVER_NAME, ("key1",))] - keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) self.assertIn(SERVER_NAME, keys) k = keys[SERVER_NAME][testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) @@ -365,9 +440,77 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.assertEqual( bytes(res["key_json"]), - canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]), + canonicaljson.encode_canonical_json(response), ) + def test_invalid_perspectives_responses(self): + """Check that invalid responses from the perspectives server are rejected""" + # arbitrarily advance the clock a bit + self.reactor.advance(100) + + SERVER_NAME = "server2" + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 200 * 1000 + + def build_response(): + # valid response + response = { + "server_name": SERVER_NAME, + "old_verify_keys": {}, + "valid_until_ts": VALID_UNTIL_TS, + "verify_keys": { + testverifykey_id: { + "key": signedjson.key.encode_verify_key_base64(testverifykey) + } + }, + } + + # the response must be signed by both the origin server and the perspectives + # server. + signedjson.sign.sign_json(response, SERVER_NAME, testkey) + self.mock_perspective_server.sign_response(response) + return response + + def get_key_from_perspectives(response): + fetcher = PerspectivesKeyFetcher(self.hs) + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + + def post_json(destination, path, data, **kwargs): + self.assertEqual(destination, self.mock_perspective_server.server_name) + self.assertEqual(path, "/_matrix/key/v2/query") + return {"server_keys": [response]} + + self.http_client.post_json.side_effect = post_json + + return self.get_success(fetcher.get_keys(keys_to_fetch)) + + # start with a valid response so we can check we are testing the right thing + response = build_response() + keys = get_key_from_perspectives(response) + k = keys[SERVER_NAME][testverifykey_id] + self.assertEqual(k.verify_key, testverifykey) + + # remove the perspectives server's signature + response = build_response() + del response["signatures"][self.mock_perspective_server.server_name] + self.http_client.post_json.return_value = {"server_keys": [response]} + keys = get_key_from_perspectives(response) + self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig") + + # remove the origin server's signature + response = build_response() + del response["signatures"][SERVER_NAME] + self.http_client.post_json.return_value = {"server_keys": [response]} + keys = get_key_from_perspectives(response) + self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") + + +def get_key_id(key): + """Get the matrix ID tag for a given SigningKey or VerifyKey""" + return "%s:%s" % (key.alg, key.version) + @defer.inlineCallbacks def run_in_context(f, *args, **kwargs): @@ -379,14 +522,16 @@ def run_in_context(f, *args, **kwargs): defer.returnValue(rv) -def _verify_json_for_server(keyring, server_name, json_object): +def _verify_json_for_server(keyring, server_name, json_object, validity_time): """thin wrapper around verify_json_for_server which makes sure it is wrapped with the patched defer.inlineCallbacks. """ @defer.inlineCallbacks def v(): - rv1 = yield keyring.verify_json_for_server(server_name, json_object) + rv1 = yield keyring.verify_json_for_server( + server_name, json_object, validity_time + ) defer.returnValue(rv1) return run_in_context(v) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py new file mode 100644
index 0000000000..1e3e5aec66 --- /dev/null +++ b/tests/federation/test_complexity.py
@@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests import unittest + + +class RoomComplexityTests(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self, name='test'): + config = super(RoomComplexityTests, self).default_config(name=name) + config["limit_large_remote_room_joins"] = True + config["limit_large_remote_room_complexity"] = 0.05 + return config + + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return defer.succeed("otherserver.nottld") + + ratelimiter = FederationRateLimiter( + clock, + FederationRateLimitConfig( + window_size=1, + sleep_limit=1, + sleep_msec=1, + reject_limit=1000, + concurrent_requests=1000, + ), + ) + server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + def test_complexity_simple(self): + + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.helper.send_state( + room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token + ) + + # Get the room complexity + request, channel = self.make_request( + "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) + ) + self.render(request) + self.assertEquals(200, channel.code) + complexity = channel.json_body["v1"] + self.assertTrue(complexity > 0, complexity) + + # Artificially raise the complexity + store = self.hs.get_datastore() + store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23) + + # Get the room complexity again -- make sure it's our artificial value + request, channel = self.make_request( + "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) + ) + self.render(request) + self.assertEquals(200, channel.code) + complexity = channel.json_body["v1"] + self.assertEqual(complexity, 1.23) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index ee5f09041f..e5fc2fcd15 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py
@@ -408,7 +408,6 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([], users_in_room) - @unittest.DEBUG def test_shutdown_room_block_peek(self): """Test that a world_readable room can no longer be peeked into after it has been shut down. diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 769c37ce52..72c7ed93cb 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py
@@ -14,6 +14,8 @@ # limitations under the License. """Tests REST events for /profile paths.""" +import json + from mock import Mock from twisted.internet import defer @@ -28,11 +30,14 @@ from tests import unittest from ....utils import MockHttpResource, setup_test_homeserver myid = "@1234ABCD:test" -PATH_PREFIX = "/_matrix/client/api/v1" +PATH_PREFIX = "/_matrix/client/r0" + +class MockHandlerProfileTestCase(unittest.TestCase): + """ Tests rest layer of profile management. -class ProfileTestCase(unittest.TestCase): - """ Tests profile management. """ + Todo: move these into ProfileTestCase + """ @defer.inlineCallbacks def setUp(self): @@ -159,6 +164,59 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") +class ProfileTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + return self.hs + + def prepare(self, reactor, clock, hs): + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + + def test_set_displayname(self): + request, channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner, ), + content=json.dumps({"displayname": "test"}), + access_token=self.owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + res = self.get_displayname() + self.assertEqual(res, "test") + + def test_set_displayname_too_long(self): + """Attempts to set a stupid displayname should get a 400""" + request, channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner, ), + content=json.dumps({"displayname": "test" * 100}), + access_token=self.owner_tok, + ) + self.render(request) + self.assertEqual(channel.code, 400, channel.result) + + res = self.get_displayname() + self.assertEqual(res, "owner") + + def get_displayname(self): + request, channel = self.make_request( + "GET", + "/profile/%s/displayname" % (self.owner, ), + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["displayname"] + + class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index d4a1d4d50c..0cb6a363d6 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -436,6 +436,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.validity_period = 10 + self.max_delta = self.validity_period * 10. / 100. config = self.default_config() @@ -453,14 +454,18 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): def test_background_job(self): """ - Tests whether the account validity startup background job does the right thing, - which is sticking an expiration date to every account that doesn't already have - one. + Tests the same thing as test_background_job, except that it sets the + startup_job_max_delta parameter and checks that the expiration date is within the + allowed range. """ - user_id = self.register_user("kermit", "user") + user_id = self.register_user("kermit_delta", "user") + + self.hs.config.account_validity.startup_job_max_delta = self.max_delta now_ms = self.hs.clock.time_msec() self.get_success(self.store._set_expiration_date_when_missing()) res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) - self.assertEqual(res, now_ms + self.validity_period) + + self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) + self.assertLessEqual(res, now_ms + self.validity_period) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py new file mode 100644
index 0000000000..6dda66ecd3 --- /dev/null +++ b/tests/storage/test_cleanup_extrems.py
@@ -0,0 +1,248 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path + +from synapse.api.constants import EventTypes +from synapse.storage import prepare_database +from synapse.types import Requester, UserID + +from tests.unittest import HomeserverTestCase + + +class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): + """Test the background update to clean forward extremities table. + """ + + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() + self.event_creator = homeserver.get_event_creation_handler() + self.room_creator = homeserver.get_room_creation_handler() + + # Create a test user and room + self.user = UserID("alice", "test") + self.requester = Requester(self.user, None, False, None, None) + info = self.get_success(self.room_creator.create_room(self.requester, {})) + self.room_id = info["room_id"] + + def create_and_send_event(self, soft_failed=False, prev_event_ids=None): + """Create and send an event. + + Args: + soft_failed (bool): Whether to create a soft failed event or not + prev_event_ids (list[str]|None): Explicitly set the prev events, + or if None just use the default + + Returns: + str: The new event's ID. + """ + prev_events_and_hashes = None + if prev_event_ids: + prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] + + event, context = self.get_success( + self.event_creator.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.user.to_string(), + "content": {"body": "", "msgtype": "m.text"}, + }, + prev_events_and_hashes=prev_events_and_hashes, + ) + ) + + if soft_failed: + event.internal_metadata.soft_failed = True + + self.get_success( + self.event_creator.send_nonmember_event(self.requester, event, context) + ) + + return event.event_id + + def add_extremity(self, event_id): + """Add the given event as an extremity to the room. + """ + self.get_success( + self.store._simple_insert( + table="event_forward_extremities", + values={"room_id": self.room_id, "event_id": event_id}, + desc="test_add_extremity", + ) + ) + + self.store.get_latest_event_ids_in_room.invalidate((self.room_id,)) + + def run_background_update(self): + """Re run the background update to clean up the extremities. + """ + # Make sure we don't clash with in progress updates. + self.assertTrue(self.store._all_done, "Background updates are still ongoing") + + schema_path = os.path.join( + prepare_database.dir_path, + "schema", + "delta", + "54", + "delete_forward_extremities.sql", + ) + + def run_delta_file(txn): + prepare_database.executescript(txn, schema_path) + + self.get_success( + self.store.runInteraction("test_delete_forward_extremities", run_delta_file) + ) + + # Ugh, have to reset this flag + self.store._all_done = False + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + def test_soft_failed_extremities_handled_correctly(self): + """Test that extremities are correctly calculated in the presence of + soft failed events. + + Tests a graph like: + + A <- SF1 <- SF2 <- B + + Where SF* are soft failed. + """ + + # Create the room graph + event_id_1 = self.create_and_send_event() + event_id_2 = self.create_and_send_event(True, [event_id_1]) + event_id_3 = self.create_and_send_event(True, [event_id_2]) + event_id_4 = self.create_and_send_event(False, [event_id_3]) + + # Check the latest events are as expected + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + + self.assertEqual(latest_event_ids, [event_id_4]) + + def test_basic_cleanup(self): + """Test that extremities are correctly calculated in the presence of + soft failed events. + + Tests a graph like: + + A <- SF1 <- B + + Where SF* are soft failed, and with extremities of A and B + """ + # Create the room graph + event_id_a = self.create_and_send_event() + event_id_sf1 = self.create_and_send_event(True, [event_id_a]) + event_id_b = self.create_and_send_event(False, [event_id_sf1]) + + # Add the new extremity and check the latest events are as expected + self.add_extremity(event_id_a) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + + # Run the background update and check it did the right thing + self.run_background_update() + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(latest_event_ids, [event_id_b]) + + def test_chain_of_fail_cleanup(self): + """Test that extremities are correctly calculated in the presence of + soft failed events. + + Tests a graph like: + + A <- SF1 <- SF2 <- B + + Where SF* are soft failed, and with extremities of A and B + """ + # Create the room graph + event_id_a = self.create_and_send_event() + event_id_sf1 = self.create_and_send_event(True, [event_id_a]) + event_id_sf2 = self.create_and_send_event(True, [event_id_sf1]) + event_id_b = self.create_and_send_event(False, [event_id_sf2]) + + # Add the new extremity and check the latest events are as expected + self.add_extremity(event_id_a) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + + # Run the background update and check it did the right thing + self.run_background_update() + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(latest_event_ids, [event_id_b]) + + def test_forked_graph_cleanup(self): + r"""Test that extremities are correctly calculated in the presence of + soft failed events. + + Tests a graph like, where time flows down the page: + + A B + / \ / + / \ / + SF1 SF2 + | | + SF3 | + / \ | + | \ | + C SF4 + + Where SF* are soft failed, and with them A, B and C marked as + extremities. This should resolve to B and C being marked as extremity. + """ + # Create the room graph + event_id_a = self.create_and_send_event() + event_id_b = self.create_and_send_event() + event_id_sf1 = self.create_and_send_event(True, [event_id_a]) + event_id_sf2 = self.create_and_send_event(True, [event_id_a, event_id_b]) + event_id_sf3 = self.create_and_send_event(True, [event_id_sf1]) + self.create_and_send_event(True, [event_id_sf2, event_id_sf3]) # SF4 + event_id_c = self.create_and_send_event(False, [event_id_sf3]) + + # Add the new extremity and check the latest events are as expected + self.add_extremity(event_id_a) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual( + set(latest_event_ids), set((event_id_a, event_id_b, event_id_c)) + ) + + # Run the background update and check it did the right thing + self.run_background_update() + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c]))