diff options
Diffstat (limited to 'tests/storage')
26 files changed, 1610 insertions, 476 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index dd49a14524..5a50e4fdd4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached from tests import unittest -class CacheTestCase(unittest.TestCase): - def setUp(self): +class CacheTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): self.cache = Cache("test") def test_empty(self): @@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase): cache.get(3) -class CacheDecoratorTestCase(unittest.TestCase): +class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): class A(object): @@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a.func.prefill(("foo",), ObservableDeferred(d)) - self.assertEquals(a.func("foo"), d.result) + self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) @defer.inlineCallbacks @@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=4) # HACK: This makes it 2 due to cache factor + @cached(max_entries=2) def func(self, key): callcount[0] += 1 return key @@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.table_name = "table_" + hs.get_secrets().token_hex(6) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "create", lambda x, *a: x.execute(*a), "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)" @@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "index", lambda x, *a: x.execute(*a), "CREATE UNIQUE INDEX %sindex ON %s(id, username)" @@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["hello"], ["there"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,13 +367,13 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, "user2", "there")]), + {(1, "user1", "hello"), (2, "user2", "there")}, ) # Update only user2 @@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["bleb"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,11 +394,11 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, "user2", "bleb")]), + {(1, "user1", "hello"), (2, "user2", "bleb")}, ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 622b16a071..ef296e7dab 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -24,10 +24,11 @@ from twisted.internet import defer from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.storage.appservice import ( +from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -42,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_token = "token1" @@ -54,7 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + self.store = ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -65,14 +69,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): pass def _add_appservice(self, as_token, id, url, hs_token, sender): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token=hs_token, - id=id, - sender_localpart=sender, - namespaces={}, - ) + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": hs_token, + "id": id, + "sender_localpart": sender, + "namespaces": {}, + } # use the token as the filename with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) @@ -106,12 +110,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] - self.db_pool = hs.get_db_pool() - self.engine = hs.database_engine - self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"}, @@ -123,17 +124,25 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - self.store = TestTransactionStore(hs.get_db_conn(), hs) + # We assume there is only one database in these tests + database = hs.get_datastores().databases[0] + self.db_pool = database._db_pool + self.engine = database.engine - def _add_service(self, url, as_token, id): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token="something", - id=id, - sender_localpart="a_sender", - namespaces={}, + db_config = hs.config.get_single_database() + self.store = TestTransactionStore( + database, make_conn(db_config, self.engine), hs ) + + def _add_service(self, url, as_token, id): + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": "something", + "id": id, + "sender_localpart": "a_sender", + "namespaces": {}, + } # use the token as the filename with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) @@ -375,15 +384,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) self.assertEquals(2, len(services)) self.assertEquals( - set([self.as_list[2]["id"], self.as_list[0]["id"]]), - set([services[0].id, services[1].id]), + {self.as_list[2]["id"], self.as_list[0]["id"]}, + {services[0].id, services[1].id}, ) # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, db_conn, hs): - super(TestTransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TestTransactionStore, self).__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): @@ -413,10 +422,13 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -428,11 +440,14 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) @@ -449,11 +464,14 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 9fabe3fbc0..940b166129 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -2,74 +2,90 @@ from mock import Mock from twisted.internet import defer +from synapse.storage.background_updates import BackgroundUpdater + from tests import unittest -from tests.utils import setup_test_homeserver -class BackgroundUpdateTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) - self.store = hs.get_datastore() - self.clock = hs.get_clock() +class BackgroundUpdateTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater + # the base test class should have run the real bg updates for us + self.assertTrue( + self.get_success(self.updates.has_completed_background_updates()) + ) self.update_handler = Mock() - - yield self.store.register_background_update_handler( + self.updates.register_background_update_handler( "test_update", self.update_handler ) - # run the real background updates, to get them out the way - # (perhaps we should run them as part of the test HS setup, since we - # run all of the other schema setup stuff there?) - while True: - res = yield self.store.do_next_background_update(1000) - if res is None: - break - - @defer.inlineCallbacks def test_do_background_update(self): - desired_count = 1000 + # the time we claim each update takes duration_ms = 42 + # the target runtime for each bg update + target_background_update_duration_ms = 50000 + + store = self.hs.get_datastore() + self.get_success( + store.db.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": '{"my_key": 1}'}, + ) + ) + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): - self.clock.advance_time_msec(count * duration_ms) + yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield self.store.runInteraction( + yield store.db.runInteraction( "update_progress", - self.store._background_update_progress_txn, + self.updates._background_update_progress_txn, "test_update", progress, ) return count self.update_handler.side_effect = update - - yield self.store.start_background_update("test_update", {"my_key": 1}) - self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) - self.assertIsNotNone(result) + res = self.get_success( + self.updates.do_next_background_update( + target_background_update_duration_ms + ), + by=0.1, + ) + self.assertFalse(res) + + # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update + # we should now get run with a much bigger number of items to update @defer.inlineCallbacks def update(progress, count): - yield self.store._end_background_update("test_update") + self.assertEqual(progress, {"my_key": 2}) + self.assertAlmostEqual( + count, target_background_update_duration_ms / duration_ms, places=0, + ) + yield self.updates._end_background_update("test_update") return count self.update_handler.side_effect = update self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) - self.assertIsNotNone(result) - self.update_handler.assert_called_once_with({"my_key": 2}, desired_count) + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) + ) + self.assertFalse(result) + self.update_handler.assert_called_once() # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) - self.assertIsNone(result) + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) + ) + self.assertTrue(result) self.assertFalse(self.update_handler.called) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index c778de1f0c..278961c331 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -21,6 +21,7 @@ from mock import Mock from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from tests import unittest @@ -50,22 +51,25 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True - config.event_cache_size = 1 - config.database_config = {"name": "sqlite3"} - engine = create_engine(config.database_config) + config.caches = Mock() + config.caches.event_cache_size = 1 + hs = TestHomeServer("test", config=config) + + sqlite_config = {"name": "sqlite3"} + engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - hs = TestHomeServer( - "test", db_pool=self.db_pool, config=config, database_engine=fake_engine - ) - self.datastore = SQLBaseStore(None, hs) + db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db._db_pool = self.db_pool + + self.datastore = SQLBaseStore(db, None, hs) @defer.inlineCallbacks def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.db.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -77,7 +81,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.db.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -92,7 +96,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore._simple_select_one_onecol( + value = yield self.datastore.db.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -106,7 +110,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -122,7 +126,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -137,7 +141,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore._simple_select_list( + ret = yield self.datastore.db.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -150,7 +154,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -165,7 +169,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -180,7 +184,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_delete_one( + yield self.datastore.db.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index e9e2d5337c..43425c969a 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -14,7 +14,13 @@ # limitations under the License. import os.path +from unittest.mock import patch +from mock import Mock + +import synapse.rest.admin +from synapse.api.constants import EventTypes +from synapse.rest.client.v1 import login, room from synapse.storage import prepare_database from synapse.types import Requester, UserID @@ -33,17 +39,21 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] def run_background_update(self): """Re run the background update to clean up the extremities. """ # Make sure we don't clash with in progress updates. - self.assertTrue(self.store._all_done, "Background updates are still ongoing") + self.assertTrue( + self.store.db.updates._all_done, "Background updates are still ongoing" + ) schema_path = os.path.join( prepare_database.dir_path, + "data_stores", + "main", "schema", "delta", "54", @@ -54,14 +64,20 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): prepare_database.executescript(txn, schema_path) self.get_success( - self.store.runInteraction("test_delete_forward_extremities", run_delta_file) + self.store.db.runInteraction( + "test_delete_forward_extremities", run_delta_file + ) ) # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of @@ -118,7 +134,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -156,7 +172,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -211,9 +227,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual( - set(latest_event_ids), set((event_id_a, event_id_b, event_id_c)) - ) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c}) # Run the background update and check it did the right thing self.run_background_update() @@ -221,10 +235,18 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c])) + self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c}) class CleanupExtremDummyEventsTestCase(HomeserverTestCase): + CONSENT_VERSION = "1" + EXTREMITIES_COUNT = 50 + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + def make_homeserver(self, reactor, clock): config = self.default_config() config["cleanup_extremities_with_dummy_events"] = True @@ -233,28 +255,39 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() self.room_creator = homeserver.get_room_creation_handler() + self.event_creator_handler = homeserver.get_event_creation_handler() # Create a test user and room - self.user = UserID("alice", "test") + self.user = UserID.from_string(self.register_user("user1", "password")) + self.token1 = self.login("user1", "password") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] + self.event_creator = homeserver.get_event_creation_handler() + homeserver.config.user_consent_version = self.CONSENT_VERSION def test_send_dummy_event(self): - # Create a bushy graph with 50 extremities. + self._create_extremity_rich_graph() - event_id_start = self.create_and_send_event(self.room_id, self.user) - - for _ in range(50): - self.create_and_send_event( - self.room_id, self.user, prev_event_ids=[event_id_start] - ) + # Pump the reactor repeatedly so that the background updates have a + # chance to run. + self.pump(10 * 60) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(len(latest_event_ids), 50) + self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) + def test_send_dummy_events_when_insufficient_power(self): + self._create_extremity_rich_graph() + # Criple power levels + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + body={"users": {str(self.user): -1}}, + tok=self.token1, + ) # Pump the reactor repeatedly so that the background updates have a # chance to run. self.pump(10 * 60) @@ -262,4 +295,108 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) + # Check that the room has not been pruned + self.assertTrue(len(latest_event_ids) > 10) + + # New user with regular levels + user2 = self.register_user("user2", "password") + token2 = self.login("user2", "password") + self.helper.join(self.room_id, user2, tok=token2) + self.pump(10 * 60) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) + def test_send_dummy_event_without_consent(self): + self._create_extremity_rich_graph() + self._enable_consent_checking() + + # Pump the reactor repeatedly so that the background updates have a + # chance to run. Attempt to add dummy event with user that has not consented + # Check that dummy event send fails. + self.pump(10 * 60) + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT) + + # Create new user, and add consent + user2 = self.register_user("user2", "password") + token2 = self.login("user2", "password") + self.get_success( + self.store.user_set_consent_version(user2, self.CONSENT_VERSION) + ) + self.helper.join(self.room_id, user2, tok=token2) + + # Background updates should now cause a dummy event to be added to the graph + self.pump(10 * 60) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250) + def test_expiry_logic(self): + """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() + expires old entries correctly. + """ + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "1" + ] = 100000 + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "2" + ] = 200000 + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "3" + ] = 300000 + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() + # All entries within time frame + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 3, + ) + # Oldest room to expire + self.pump(1) + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 2, + ) + # All rooms to expire + self.pump(2) + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 0, + ) + + def _create_extremity_rich_graph(self): + """Helper method to create bushy graph on demand""" + + event_id_start = self.create_and_send_event(self.room_id, self.user) + + for _ in range(self.EXTREMITIES_COUNT): + self.create_and_send_event( + self.room_id, self.user, prev_event_ids=[event_id_start] + ) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(len(latest_event_ids), 50) + + def _enable_consent_checking(self): + """Helper method to enable consent checking""" + self.event_creator._block_events_without_consent_error = "No consent from user" + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creator._consent_uri_builder = consent_uri_builder diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 09305c3bf1..3b483bc7f0 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -23,6 +23,7 @@ from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.unittest import override_config class ClientIpStoreTestCase(unittest.HomeserverTestCase): @@ -37,9 +38,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(12345678) user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -47,15 +52,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", - "access_token": "access_token", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, @@ -82,7 +86,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -113,7 +117,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -134,9 +138,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"limit_usage_by_mau": False, "max_mau_value": 50}) def test_disabled_monthly_active_user(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 user_id = "@user:server" self.get_success( self.store.insert_client_ip( @@ -146,9 +149,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_adding_monthly_active_user_when_full(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 lots_of_users = 100 user_id = "@user:server" @@ -163,9 +165,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_adding_monthly_active_user_when_space(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 user_id = "@user:server" active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) @@ -181,9 +182,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_updating_monthly_active_user_when_space(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 user_id = "@user:server" self.get_success(self.store.register_user(user_id=user_id, password_hash=None)) @@ -201,6 +201,173 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + def test_devices_last_seen_bg_update(self): + # First make sure we have completed all updates. + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) + + user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", device_id + ) + ) + # Force persisting to disk + self.reactor.advance(200) + + # But clear the associated entry in devices table + self.get_success( + self.store.db.simple_update( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues={"last_seen": None, "ip": None, "user_agent": None}, + desc="test_devices_last_seen_bg_update", + ) + ) + + # We should now get nulls when querying + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, device_id) + ) + + r = result[(user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": device_id, + "ip": None, + "user_agent": None, + "last_seen": None, + }, + r, + ) + + # Register the background update to run again. + self.get_success( + self.store.db.simple_insert( + table="background_updates", + values={ + "update_name": "devices_last_seen", + "progress_json": "{}", + "depends_on": None, + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db.updates._all_done = False + + # Now let's actually drive the updates to completion + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) + + # We should now get the correct result again + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, device_id) + ) + + r = result[(user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": device_id, + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 0, + }, + r, + ) + + def test_old_user_ips_pruned(self): + # First make sure we have completed all updates. + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) + + user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", device_id + ) + ) + + # Force persisting to disk + self.reactor.advance(200) + + # We should see that in the DB + result = self.get_success( + self.store.db.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + + self.assertEqual( + result, + [ + { + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": device_id, + "last_seen": 0, + } + ], + ) + + # Now advance by a couple of months + self.reactor.advance(60 * 24 * 60 * 60) + + # We should get no results. + result = self.get_success( + self.store.db.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + + self.assertEqual(result, []) + + # But we should still get the correct values for the device + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, device_id) + ) + + r = result[(user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": device_id, + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 0, + }, + r, + ) + class ClientIpAuthTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py new file mode 100644 index 0000000000..5a77c84962 --- /dev/null +++ b/tests/storage/test_database.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.database import make_tuple_comparison_clause +from synapse.storage.engines import BaseDatabaseEngine + +from tests import unittest + + +def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: + # returns a DatabaseEngine, circumventing the abc mechanism + # any kwargs are set as attributes on the class before instantiating it + t = type( + "TestBaseDatabaseEngine", + (BaseDatabaseEngine,), + dict(BaseDatabaseEngine.__dict__), + ) + # defeat the abc mechanism + t.__abstractmethods__ = set() + for k, v in kwargs.items(): + setattr(t, k, v) + return t(None, None) + + +class TupleComparisonClauseTestCase(unittest.TestCase): + def test_native_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=True) + clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)]) + self.assertEqual(clause, "(a,b) > (?,?)") + self.assertEqual(args, [1, 2]) + + def test_emulated_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=False) + clause, args = make_tuple_comparison_clause( + db_engine, [("a", 1), ("b", 2), ("c", 3)] + ) + self.assertEqual( + clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))" + ) + self.assertEqual(args, [1, 1, 2, 2, 3]) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 3cc18f9f1c..c2539b353a 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) @defer.inlineCallbacks - def test_get_devices_by_remote(self): + def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id @@ -81,63 +81,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_devices_by_remote( + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( "somehost", -1, limit=100 ) # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) - @defer.inlineCallbacks - def test_get_devices_by_remote_limited(self): - # Test breaking the update limit in 1, 101, and 1 device_id segments - - # first add one device - device_ids1 = ["device_id0"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids1, ["someotherhost"] - ) - - # then add 101 - device_ids2 = ["device_id" + str(i + 1) for i in range(101)] - yield self.store.add_device_change_to_streams( - "user_id", device_ids2, ["someotherhost"] - ) - - # then one more - device_ids3 = ["newdevice"] - yield self.store.add_device_change_to_streams( - "user_id", device_ids3, ["someotherhost"] - ) - - # - # now read them back. - # - - # first we should get a single update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", -1, limit=100 - ) - self._check_devices_in_updates(device_ids1, device_updates) - - # Then we should get an empty list back as the 101 devices broke the limit - now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self.assertEqual(len(device_updates), 0) - - # The 101 devices should've been cleared, so we should now just get one device - # update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100 - ) - self._check_devices_in_updates(device_ids3, device_updates) - def _check_devices_in_updates(self, expected_device_ids, device_updates): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) - received_device_ids = {update["device_id"] for update in device_updates} + received_device_ids = { + update["device_id"] for edu_type, update in device_updates + } self.assertEqual(received_device_ids, set(expected_device_ids)) @defer.inlineCallbacks diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py new file mode 100644 index 0000000000..35dafbb904 --- /dev/null +++ b/tests/storage/test_e2e_room_keys.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests import unittest + +# sample room_key data for use in the tests +room_key = { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": False, + "session_data": "SSBBTSBBIEZJU0gK", +} + + +class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) + self.store = hs.get_datastore() + return hs + + def test_room_keys_version_delete(self): + # test that deleting a room key backup deletes the keys + version1 = self.get_success( + self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + ) + + self.get_success( + self.store.add_e2e_room_keys( + "user_id", version1, [("room", "session", room_key)] + ) + ) + + version2 = self.get_success( + self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + ) + + self.get_success( + self.store.add_e2e_room_keys( + "user_id", version2, [("room", "session", room_key)] + ) + ) + + # make sure the keys were stored properly + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1)) + self.assertEqual(len(keys["rooms"]), 1) + + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2)) + self.assertEqual(len(keys["rooms"]), 1) + + # delete version1 + self.get_success(self.store.delete_e2e_room_keys_version("user_id", version1)) + + # make sure the key from version1 is gone, and the key from version2 is + # still there + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1)) + self.assertEqual(len(keys["rooms"]), 0) + + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2)) + self.assertEqual(len(keys["rooms"]), 1) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index c8ece15284..398d546280 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -38,7 +38,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev) + self.assertDictContainsSubset(json, dev) @defer.inlineCallbacks def test_reupload_key(self): @@ -68,7 +68,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): self.assertIn("device", res["user"]) dev = res["user"]["device"] self.assertDictContainsSubset( - {"keys": json, "device_display_name": "display_name"}, dev + {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev ) @defer.inlineCallbacks @@ -80,10 +80,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.store_device("user2", "device1", None) yield self.store.store_device("user2", "device2", None) - yield self.store.set_e2e_device_keys("user1", "device1", now, "json11") - yield self.store.set_e2e_device_keys("user1", "device2", now, "json12") - yield self.store.set_e2e_device_keys("user2", "device1", now, "json21") - yield self.store.set_e2e_device_keys("user2", "device2", now, "json22") + yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) + yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) + yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) + yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) res = yield self.store.get_e2e_device_keys( (("user1", "device1"), ("user2", "device2")) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 86c7ac350d..3aeec0dc0f 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,19 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import tests.unittest import tests.utils -class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks def test_get_prev_events_for_room(self): room_id = "@ROOM:local" @@ -57,21 +52,182 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): "(event_id, algorithm, hash) " "VALUES (?, 'sha256', ?)" ), - (event_id, b"ffff"), + (event_id, bytearray(b"ffff")), ) - for i in range(0, 11): - yield self.store.runInteraction("insert", insert_event, i) + for i in range(0, 20): + self.get_success(self.store.db.runInteraction("insert", insert_event, i)) - # this should get the last five and five others - r = yield self.store.get_prev_events_for_room(room_id) + # this should get the last ten + r = self.get_success(self.store.get_prev_events_for_room(room_id)) self.assertEqual(10, len(r)) - for i in range(0, 5): - el = r[i] - depth = el[2] - self.assertEqual(10 - i, depth) - - for i in range(5, 5): - el = r[i] - depth = el[2] - self.assertLessEqual(5, depth) + for i in range(0, 10): + self.assertEqual("$event_%i:local" % (19 - i), r[i]) + + def test_get_rooms_with_many_extremities(self): + room1 = "#room1" + room2 = "#room2" + room3 = "#room3" + + def insert_event(txn, i, room_id): + event_id = "$event_%i:local" % i + txn.execute( + ( + "INSERT INTO event_forward_extremities (room_id, event_id) " + "VALUES (?, ?)" + ), + (room_id, event_id), + ) + + for i in range(0, 20): + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room1) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room2) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room3) + ) + + # Test simple case + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [])) + self.assertEqual(len(r), 3) + + # Does filter work? + + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1])) + self.assertTrue(room2 in r) + self.assertTrue(room3 in r) + self.assertEqual(len(r), 2) + + r = self.get_success( + self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + ) + self.assertEqual(r, [room3]) + + # Does filter and limit work? + + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) + self.assertTrue(r == [room2] or r == [room3]) + + def test_auth_difference(self): + room_id = "@ROOM:local" + + # The silly auth graph we use to test the auth difference algorithm, + # where the top are the most recent events. + # + # A B + # \ / + # D E + # \ | + # ` F C + # | /| + # G ´ | + # | \ | + # H I + # | | + # K J + + auth_graph = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], + } + + depth_map = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, + } + + # We rudely fiddle with the appropriate tables directly, as that's much + # easier than constructing events properly. + + def insert_event(txn, event_id, stream_ordering): + + depth = depth_map[event_id] + + self.store.db.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + self.store.db.simple_insert_many_txn( + txn, + table="event_auth", + values=[ + {"event_id": event_id, "room_id": room_id, "auth_id": a} + for a in auth_graph[event_id] + ], + ) + + next_stream_ordering = 0 + for event_id in auth_graph: + next_stream_ordering += 1 + self.get_success( + self.store.db.runInteraction( + "insert", insert_event, event_id, next_stream_ordering + ) + ) + + # Now actually test that various combinations give the right result: + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "d", "e"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) + self.assertSetEqual(difference, set()) diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index f26ff57a18..a7b85004e5 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): events = [(3, 2), (6, 2), (4, 6)] for event_count, extrems in events: - info = self.get_success(room_creator.create_room(requester, {})) + info, _ = self.get_success(room_creator.create_room(requester, {})) room_id = info["room_id"] last_event = None @@ -59,24 +59,22 @@ class ExtremStatisticsTestCase(HomeserverTestCase): ) ) - expected = set( - [ - b'synapse_forward_extremities_bucket{le="1.0"} 0.0', - b'synapse_forward_extremities_bucket{le="2.0"} 2.0', - b'synapse_forward_extremities_bucket{le="3.0"} 2.0', - b'synapse_forward_extremities_bucket{le="5.0"} 2.0', - b'synapse_forward_extremities_bucket{le="7.0"} 3.0', - b'synapse_forward_extremities_bucket{le="10.0"} 3.0', - b'synapse_forward_extremities_bucket{le="15.0"} 3.0', - b'synapse_forward_extremities_bucket{le="20.0"} 3.0', - b'synapse_forward_extremities_bucket{le="50.0"} 3.0', - b'synapse_forward_extremities_bucket{le="100.0"} 3.0', - b'synapse_forward_extremities_bucket{le="200.0"} 3.0', - b'synapse_forward_extremities_bucket{le="500.0"} 3.0', - b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', - b"synapse_forward_extremities_count 3.0", - b"synapse_forward_extremities_sum 10.0", - ] - ) + expected = { + b'synapse_forward_extremities_bucket{le="1.0"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0"} 3.0', + b'synapse_forward_extremities_bucket{le="50.0"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0"} 3.0', + b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', + b"synapse_forward_extremities_count 3.0", + b"synapse_forward_extremities_sum 10.0", + } self.assertEqual(items, expected) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b114c6fb1d..b45bc9c115 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -35,6 +35,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.persist_events_store = hs.get_datastores().persist_events @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): @@ -55,7 +56,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.runInteraction( + counts = yield self.store.db.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( @@ -74,20 +75,20 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield self.store.add_push_actions_to_staging( event.event_id, {user_id: action} ) - yield self.store.runInteraction( + yield self.store.db.runInteraction( "", - self.store._set_push_actions_for_event_and_users_txn, + self.persist_events_store._set_push_actions_for_event_and_users_txn, [(event, None)], [(event, None)], ) def _rotate(stream): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) def _mark_read(stream, depth): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._remove_old_push_actions_before_txn, room_id, @@ -116,7 +117,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store._simple_delete( + yield self.store.db.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -135,7 +136,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store._simple_insert( + return self.store.db.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py new file mode 100644 index 0000000000..55e9ecf264 --- /dev/null +++ b/tests/storage/test_id_generators.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.storage.database import Database +from synapse.storage.util.id_generators import MultiWriterIdGenerator + +from tests.unittest import HomeserverTestCase +from tests.utils import USE_POSTGRES_FOR_TESTS + + +class MultiWriterIdGeneratorTestCase(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db = self.store.db # type: Database + + self.get_success(self.db.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db, + instance_name=instance_name, + table="foobar", + instance_column="instance_name", + id_column="stream_id", + sequence_name="foobar_seq", + ) + + return self.get_success(self.db.runWithConnection(_create)) + + def _insert_rows(self, instance_name: str, number: int): + def _insert(txn): + for _ in range(number): + txn.execute( + "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", + (instance_name,), + ) + + self.get_success(self.db.runInteraction("test_single_instance", _insert)) + + def test_empty(self): + """Test an ID generator against an empty database gives sensible + current positions. + """ + + id_gen = self._create_id_generator() + + # The table is empty so we expect an empty map for positions + self.assertEqual(id_gen.get_positions(), {}) + + def test_single_instance(self): + """Test that reads and writes from a single process are handled + correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + async def _get_next_async(): + with await id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + self.get_success(_get_next_async()) + + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token("master"), 8) + + def test_multi_instance(self): + """Test that reads and writes from multiple processes are handled + correctly. + """ + self._insert_rows("first", 3) + self._insert_rows("second", 4) + + first_id_gen = self._create_id_generator("first") + second_id_gen = self._create_id_generator("second") + + self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(first_id_gen.get_current_token("first"), 3) + self.assertEqual(first_id_gen.get_current_token("second"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + async def _get_next_async(): + with await first_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 8) + + self.assertEqual( + first_id_gen.get_positions(), {"first": 3, "second": 7} + ) + + self.get_success(_get_next_async()) + + self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7}) + + # However the ID gen on the second instance won't have seen the update + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + + # ... but calling `get_next` on the second instance should give a unique + # stream ID + + async def _get_next_async(): + with await second_id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 9) + + self.assertEqual( + second_id_gen.get_positions(), {"first": 3, "second": 7} + ) + + self.get_success(_get_next_async()) + + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) + + # If the second ID gen gets told about the first, it correctly updates + second_id_gen.advance("first", 8) + self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) + + def test_get_next_txn(self): + """Test that the `get_next_txn` function works correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + # Try allocating a new ID gen and check that we only see position + # advanced after we leave the context manager. + + def _get_next_txn(txn): + stream_id = id_gen.get_next_txn(txn) + self.assertEqual(stream_id, 8) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token("master"), 7) + + self.get_success(self.db.runInteraction("test", _get_next_txn)) + + self.assertEqual(id_gen.get_positions(), {"master": 8}) + self.assertEqual(id_gen.get_current_token("master"), 8) diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index e07ff01201..95f309fbbc 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -14,6 +14,7 @@ # limitations under the License. import signedjson.key +import unpaddedbase64 from twisted.internet.defer import Deferred @@ -21,11 +22,17 @@ from synapse.storage.keys import FetchKeyResult import tests.unittest -KEY_1 = signedjson.key.decode_verify_key_base64( - "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" + +def decode_verify_key_base64(key_id: str, key_base64: str): + key_bytes = unpaddedbase64.decode_base64(key_base64) + return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) + + +KEY_1 = decode_verify_key_base64( + "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" ) -KEY_2 = signedjson.key.decode_verify_key_base64( - "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" +KEY_2 = decode_verify_key_base64( + "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" ) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py new file mode 100644 index 0000000000..ab0df5ea93 --- /dev/null +++ b/tests/storage/test_main.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Awesome Technologies Innovationslabor GmbH +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.types import UserID + +from tests import unittest +from tests.utils import setup_test_homeserver + + +class DataStoreTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + hs = yield setup_test_homeserver(self.addCleanup) + + self.store = hs.get_datastore() + + self.user = UserID.from_string("@abcde:test") + self.displayname = "Frank" + + @defer.inlineCallbacks + def test_get_users_paginate(self): + yield self.store.register_user(self.user.to_string(), "pass") + yield self.store.create_profile(self.user.localpart) + yield self.store.set_profile_displayname(self.user.localpart, self.displayname) + + users, total = yield self.store.get_users_paginate( + 0, 10, name="bc", guests=False + ) + + self.assertEquals(1, total) + self.assertEquals(self.displayname, users.pop()["displayname"]) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 1494650d10..9c04e92577 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -19,152 +19,222 @@ from twisted.internet import defer from synapse.api.constants import UserTypes from tests import unittest +from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 +def gen_3pids(count): + """Generate `count` threepids as a list.""" + return [ + {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count) + ] + + class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def default_config(self): + config = default_config("test") + + config.update({"limit_usage_by_mau": True, "max_mau_value": 50}) + + # apply any additional config which was specified via the override_config + # decorator. + if self._extra_config is not None: + config.update(self._extra_config) - hs = self.setup_test_homeserver() - self.store = hs.get_datastore() - hs.config.limit_usage_by_mau = True - hs.config.max_mau_value = 50 + return config + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() # Advance the clock a bit reactor.advance(FORTY_DAYS) - return hs - + @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) def test_initialise_reserved_users(self): - self.hs.config.max_mau_value = 5 + threepids = self.hs.config.mau_limits_reserved_threepids + + # register three users, of which two have reserved 3pids, and a third + # which is a support user. user1 = "@user1:server" - user1_email = "user1@matrix.org" + user1_email = threepids[0]["address"] user2 = "@user2:server" - user2_email = "user2@matrix.org" + user2_email = threepids[1]["address"] user3 = "@user3:server" - user3_email = "user3@matrix.org" - - threepids = [ - {"medium": "email", "address": user1_email}, - {"medium": "email", "address": user2_email}, - {"medium": "email", "address": user3_email}, - ] - # -1 because user3 is a support user and does not count - user_num = len(threepids) - 1 - self.store.register_user(user_id=user1, password_hash=None) - self.store.register_user(user_id=user2, password_hash=None) - self.store.register_user( - user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT + self.get_success(self.store.register_user(user_id=user1)) + self.get_success(self.store.register_user(user_id=user2)) + self.get_success( + self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT) ) - self.pump() now = int(self.hs.get_clock().time_msec()) - self.store.user_add_threepid(user1, "email", user1_email, now, now) - self.store.user_add_threepid(user2, "email", user2_email, now, now) + self.get_success( + self.store.user_add_threepid(user1, "email", user1_email, now, now) + ) + self.get_success( + self.store.user_add_threepid(user2, "email", user2_email, now, now) + ) - self.store.runInteraction( - "initialise", self.store._initialise_reserved_users, threepids + # XXX why are we doing this here? this function is only run at startup + # so it is odd to re-run it here. + self.get_success( + self.store.db.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) ) - self.pump() - active_count = self.store.get_monthly_active_count() + # the number of users we expect will be counted against the mau limit + # -1 because user3 is a support user and does not count + user_num = len(threepids) - 1 - # Test total counts, ensure user3 (support user) is not counted - self.assertEquals(self.get_success(active_count), user_num) + # Check the number of active users. Ensure user3 (support user) is not counted + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(active_count, user_num) - # Test user is marked as active - timestamp = self.store.user_last_seen_monthly_active(user1) - self.assertTrue(self.get_success(timestamp)) - timestamp = self.store.user_last_seen_monthly_active(user2) - self.assertTrue(self.get_success(timestamp)) + # Test each of the registered users is marked as active + timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1)) + self.assertGreater(timestamp, 0) + timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2)) + self.assertGreater(timestamp, 0) - # Test that users are never removed from the db. + # Test that users with reserved 3pids are not removed from the MAU table + # XXX some of this is redundant. poking things into the config shouldn't + # work, and in any case it's not obvious what we expect to happen when + # we advance the reactor. self.hs.config.max_mau_value = 0 - self.reactor.advance(FORTY_DAYS) + self.hs.config.max_mau_value = 5 - self.store.reap_monthly_active_users() - self.pump() + self.get_success(self.store.reap_monthly_active_users()) - active_count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(active_count), user_num) + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(active_count, user_num) - # Test that regular users are removed from the db + # Add some more users and check they are counted as active ru_count = 2 - self.store.upsert_monthly_active_user("@ru1:server") - self.store.upsert_monthly_active_user("@ru2:server") - self.pump() - active_count = self.store.get_monthly_active_count() - self.assertEqual(self.get_success(active_count), user_num + ru_count) - self.hs.config.max_mau_value = user_num - self.store.reap_monthly_active_users() - self.pump() + self.get_success(self.store.upsert_monthly_active_user("@ru1:server")) + self.get_success(self.store.upsert_monthly_active_user("@ru2:server")) + + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(active_count, user_num + ru_count) - active_count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(active_count), user_num) + # now run the reaper and check that the number of active users is reduced + # to max_mau_value + self.get_success(self.store.reap_monthly_active_users()) + + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(active_count, 3) def test_can_insert_and_count_mau(self): - count = self.store.get_monthly_active_count() - self.assertEqual(0, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) - self.store.upsert_monthly_active_user("@user:server") - self.pump() + d = self.store.upsert_monthly_active_user("@user:server") + self.get_success(d) - count = self.store.get_monthly_active_count() - self.assertEqual(1, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 1) def test_user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" - result = self.store.user_last_seen_monthly_active(user_id1) - self.assertFalse(self.get_success(result) == 0) + result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) + self.assertNotEqual(result, 0) - self.store.upsert_monthly_active_user(user_id1) - self.store.upsert_monthly_active_user(user_id2) - self.pump() + self.get_success(self.store.upsert_monthly_active_user(user_id1)) + self.get_success(self.store.upsert_monthly_active_user(user_id2)) - result = self.store.user_last_seen_monthly_active(user_id1) - self.assertGreater(self.get_success(result), 0) + result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) + self.assertGreater(result, 0) - result = self.store.user_last_seen_monthly_active(user_id3) - self.assertNotEqual(self.get_success(result), 0) + result = self.get_success(self.store.user_last_seen_monthly_active(user_id3)) + self.assertNotEqual(result, 0) + @override_config({"max_mau_value": 5}) def test_reap_monthly_active_users(self): - self.hs.config.max_mau_value = 5 initial_users = 10 for i in range(initial_users): - self.store.upsert_monthly_active_user("@user%d:server" % i) - self.pump() + self.get_success( + self.store.upsert_monthly_active_user("@user%d:server" % i) + ) - count = self.store.get_monthly_active_count() - self.assertTrue(self.get_success(count), initial_users) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, initial_users) - self.store.reap_monthly_active_users() - self.pump() - count = self.store.get_monthly_active_count() - self.assertEquals( - self.get_success(count), initial_users - self.hs.config.max_mau_value - ) + d = self.store.reap_monthly_active_users() + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, self.hs.config.max_mau_value) self.reactor.advance(FORTY_DAYS) - self.store.reap_monthly_active_users() - self.pump() - count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(count), 0) + d = self.store.reap_monthly_active_users() + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) + + # Note that below says mau_limit (no s), this is the name of the config + # value, although it gets stored on the config object as mau_limits. + @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)}) + def test_reap_monthly_active_users_reserved_users(self): + """ Tests that reaping correctly handles reaping where reserved users are + present""" + threepids = self.hs.config.mau_limits_reserved_threepids + initial_users = len(threepids) + reserved_user_number = initial_users - 1 + for i in range(initial_users): + user = "@user%d:server" % i + email = "user%d@matrix.org" % i + + self.get_success(self.store.upsert_monthly_active_user(user)) + + # Need to ensure that the most recent entries in the + # monthly_active_users table are reserved + now = int(self.hs.get_clock().time_msec()) + if i != 0: + self.get_success( + self.store.register_user(user_id=user, password_hash=None) + ) + self.get_success( + self.store.user_add_threepid(user, "email", email, now, now) + ) + + d = self.store.db.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, initial_users) + + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEqual(len(users), reserved_user_number) + + d = self.store.reap_monthly_active_users() + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, self.hs.config.max_mau_value) def test_populate_monthly_users_is_guest(self): # Test that guest users are not added to mau list user_id = "@user_id:host" - self.store.register_user(user_id=user_id, password_hash=None, make_guest=True) + + d = self.store.register_user( + user_id=user_id, password_hash=None, make_guest=True + ) + self.get_success(d) + self.store.upsert_monthly_active_user = Mock() - self.store.populate_monthly_active_users(user_id) - self.pump() + + d = self.store.populate_monthly_active_users(user_id) + self.get_success(d) + self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): @@ -175,8 +245,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) - self.store.populate_monthly_active_users("user_id") - self.pump() + d = self.store.populate_monthly_active_users("user_id") + self.get_success(d) + self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): @@ -186,80 +257,132 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) - self.store.populate_monthly_active_users("user_id") - self.pump() + + d = self.store.populate_monthly_active_users("user_id") + self.get_success(d) + self.store.upsert_monthly_active_user.assert_not_called() def test_get_reserved_real_user_account(self): # Test no reserved users, or reserved threepids - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), 0) - # Test reserved users but no registered users + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEqual(len(users), 0) + # Test reserved users but no registered users user1 = "@user1:example.com" user2 = "@user2:example.com" + user1_email = "user1@example.com" user2_email = "user2@example.com" threepids = [ {"medium": "email", "address": user1_email}, {"medium": "email", "address": user2_email}, ] + self.hs.config.mau_limits_reserved_threepids = threepids - self.store.runInteraction( + d = self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) + self.get_success(d) - self.pump() - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), 0) + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEqual(len(users), 0) - # Test reserved registed users - self.store.register_user(user_id=user1, password_hash=None) - self.store.register_user(user_id=user2, password_hash=None) - self.pump() + # Test reserved registered users + self.get_success(self.store.register_user(user_id=user1, password_hash=None)) + self.get_success(self.store.register_user(user_id=user2, password_hash=None)) now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), len(threepids)) + + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEqual(len(users), len(threepids)) def test_support_user_not_add_to_mau_limits(self): support_user_id = "@support:test" - count = self.store.get_monthly_active_count() - self.pump() - self.assertEqual(self.get_success(count), 0) - self.store.register_user( + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) + + d = self.store.register_user( user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT ) + self.get_success(d) - self.store.upsert_monthly_active_user(support_user_id) - count = self.store.get_monthly_active_count() - self.pump() - self.assertEqual(self.get_success(count), 0) + d = self.store.upsert_monthly_active_user(support_user_id) + self.get_success(d) - def test_track_monthly_users_without_cap(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.mau_stats_only = True - self.hs.config.max_mau_value = 1 # should not matter + d = self.store.get_monthly_active_count() + count = self.get_success(d) + self.assertEqual(count, 0) - count = self.store.get_monthly_active_count() - self.assertEqual(0, self.get_success(count)) + # Note that the max_mau_value setting should not matter. + @override_config( + {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1} + ) + def test_track_monthly_users_without_cap(self): + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(0, count) - self.store.upsert_monthly_active_user("@user1:server") - self.store.upsert_monthly_active_user("@user2:server") - self.pump() + self.get_success(self.store.upsert_monthly_active_user("@user1:server")) + self.get_success(self.store.upsert_monthly_active_user("@user2:server")) - count = self.store.get_monthly_active_count() - self.assertEqual(2, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(2, count) + @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.mau_stats_only = False self.store.upsert_monthly_active_user = Mock() - self.store.populate_monthly_active_users("@user:sever") - self.pump() + self.get_success(self.store.populate_monthly_active_users("@user:sever")) self.store.upsert_monthly_active_user.assert_not_called() + + def test_get_monthly_active_count_by_service(self): + appservice1_user1 = "@appservice1_user1:example.com" + appservice1_user2 = "@appservice1_user2:example.com" + + appservice2_user1 = "@appservice2_user1:example.com" + native_user1 = "@native_user1:example.com" + + service1 = "service1" + service2 = "service2" + native = "native" + + self.get_success( + self.store.register_user( + user_id=appservice1_user1, password_hash=None, appservice_id=service1 + ) + ) + self.get_success( + self.store.register_user( + user_id=appservice1_user2, password_hash=None, appservice_id=service1 + ) + ) + self.get_success( + self.store.register_user( + user_id=appservice2_user1, password_hash=None, appservice_id=service2 + ) + ) + self.get_success( + self.store.register_user(user_id=native_user1, password_hash=None) + ) + + count = self.get_success(self.store.get_monthly_active_count_by_service()) + self.assertEqual(count, {}) + + self.get_success(self.store.upsert_monthly_active_user(native_user1)) + self.get_success(self.store.upsert_monthly_active_user(appservice1_user1)) + self.get_success(self.store.upsert_monthly_active_user(appservice1_user2)) + self.get_success(self.store.upsert_monthly_active_user(appservice2_user1)) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 4) + + d = self.store.get_monthly_active_count_by_service() + result = self.get_success(d) + + self.assertEqual(result[service1], 2) + self.assertEqual(result[service2], 1) + self.assertEqual(result[native], 1) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 45824bd3b2..9b6f7211ae 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -16,7 +16,6 @@ from twisted.internet import defer -from synapse.storage.profile import ProfileStore from synapse.types import UserID from tests import unittest @@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = ProfileStore(hs.get_db_conn(), hs) + self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index f671599cb8..b9fafaa1a6 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -40,23 +40,24 @@ class PurgeTests(HomeserverTestCase): third = self.helper.send(self.room_id, body="test3") last = self.helper.send(self.room_id, body="test4") - storage = self.hs.get_datastore() + store = self.hs.get_datastore() + storage = self.hs.get_storage() # Get the topological token - event = storage.get_topological_token_for_event(last["event_id"]) + event = store.get_topological_token_for_event(last["event_id"]) self.pump() event = self.successResultOf(event) # Purge everything before this topological token - purge = storage.purge_history(self.room_id, event, True) + purge = storage.purge_events.purge_history(self.room_id, event, True) self.pump() self.assertEqual(self.successResultOf(purge), None) # Try and get the events - get_first = storage.get_event(first["event_id"]) - get_second = storage.get_event(second["event_id"]) - get_third = storage.get_event(third["event_id"]) - get_last = storage.get_event(last["event_id"]) + get_first = store.get_event(first["event_id"]) + get_second = store.get_event(second["event_id"]) + get_third = store.get_event(third["event_id"]) + get_last = store.get_event(last["event_id"]) self.pump() # 1-3 should fail and last will succeed, meaning that 1-3 are deleted diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index deecfad9fb..db3667dc43 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -116,7 +117,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.store.persist_event(event, context)) + self.get_success(self.storage.persistence.persist_event(event, context)) + + return event def test_redact(self): self.get_success( @@ -235,8 +238,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def build(self, prev_event_ids): built_event = yield self._base_builder.build(prev_event_ids) - built_event.event_id = self._event_id - built_event._event_dict["event_id"] = self._event_id + + built_event._event_id = self._event_id + built_event._dict["event_id"] = self._event_id + assert built_event.event_id == self._event_id + return built_event @property @@ -261,7 +267,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) - self.get_success(self.store.persist_event(event_1, context_1)) + self.get_success(self.storage.persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( @@ -280,7 +286,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) - self.get_success(self.store.persist_event(event_2, context_2)) + self.get_success(self.storage.persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) @@ -335,7 +341,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -353,7 +359,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -361,3 +367,72 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) self.assert_dict({"content": {}}, json.loads(event_json)) + + def test_redact_redaction(self): + """Tests that we can redact a redaction and can fetch it again. + """ + + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) + + msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + + first_redact_event = self.get_success( + self.inject_redaction( + self.room1, msg_event.event_id, self.u_alice, "Redacting message" + ) + ) + + self.get_success( + self.inject_redaction( + self.room1, + first_redact_event.event_id, + self.u_alice, + "Redacting redaction", + ) + ) + + # Now lets jump to the future where we have censored the redaction event + # in the DB. + self.reactor.advance(60 * 60 * 24 * 31) + + # We just want to check that fetching the event doesn't raise an exception. + self.get_success( + self.store.get_event(first_redact_event.event_id, allow_none=True) + ) + + def test_store_redacted_redaction(self): + """Tests that we can store a redacted redaction. + """ + + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) + + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "foo"}, + }, + ) + + redaction_event, context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) + ) + + self.get_success( + self.storage.persistence.persist_event(redaction_event, context) + ) + + # Now lets jump to the future where we have censored the redaction event + # in the DB. + self.reactor.advance(60 * 60 * 24 * 31) + + # We just want to check that fetching the event doesn't raise an exception. + self.get_success( + self.store.get_event(redaction_event.event_id, allow_none=True) + ) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 4578cc3b60..71a40a0a49 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() @@ -44,12 +43,14 @@ class RegistrationStoreTestCase(unittest.TestCase): # TODO(paul): Surely this field should be 'user_id', not 'name' "name": self.user_id, "password_hash": self.pwhash, + "admin": 0, "is_guest": 0, "consent_version": None, "consent_server_notice_sent": None, "appservice_id": None, "creation_ts": 1000, "user_type": None, + "deactivated": 0, }, (yield self.store.get_user_by_id(self.user_id)), ) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1bee45706f..3b78d48896 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes +from synapse.api.room_versions import RoomVersions from synapse.types import RoomAlias, RoomID, UserID from tests import unittest @@ -40,6 +41,7 @@ class RoomStoreTestCase(unittest.TestCase): self.room.to_string(), room_creator_user_id=self.u_creator.to_string(), is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks @@ -53,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room(self.room.to_string())), ) + @defer.inlineCallbacks + def test_get_room_with_stats(self): + self.assertDictContainsSubset( + { + "room_id": self.room.to_string(), + "creator": self.u_creator.to_string(), + "public": True, + }, + (yield self.store.get_room_with_stats(self.room.to_string())), + ) + class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks @@ -62,17 +75,21 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") yield self.store.store_room( - self.room.to_string(), room_creator_user_id="@creator:text", is_public=True + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.store.persist_event( + yield self.storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 447a3c6ffb..5dd46005e6 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -16,13 +16,14 @@ from unittest.mock import Mock -from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import RoomVersions +from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room from synapse.types import Requester, UserID from tests import unittest +from tests.test_utils import event_injection +from tests.utils import TestHomeServer class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @@ -39,13 +40,11 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor, clock, hs: TestHomeServer): # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() - self.event_builder_factory = hs.get_event_builder_factory() - self.event_creation_handler = hs.get_event_creation_handler() self.u_alice = self.register_user("alice", "pass") self.t_alice = self.login("alice", "pass") @@ -54,33 +53,13 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # User elsewhere on another host self.u_charlie = UserID.from_string("@charlie:elsewhere") - def inject_room_member(self, room, user, membership, replaces_state=None): - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - - self.get_success(self.store.persist_event(event, context)) - - return event - def test_one_member(self): # Alice creates the room, and is automatically joined self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) rooms_for_user = self.get_success( - self.store.get_rooms_for_user_where_membership_is( + self.store.get_rooms_for_local_user_where_membership_is( self.u_alice, [Membership.JOIN] ) ) @@ -137,6 +116,52 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) + def test_get_joined_users_from_context(self): + room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + bob_event = event_injection.inject_member_event( + self.hs, room, self.u_bob, Membership.JOIN + ) + + # first, create a regular event + event, context = event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[bob_event.event_id], + type="m.test.1", + content={}, + ) + + users = self.get_success( + self.store.get_joined_users_from_context(event, context) + ) + self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + + # Regression test for #7376: create a state event whose key matches bob's + # user_id, but which is *not* a membership event, and persist that; then check + # that `get_joined_users_from_context` returns the correct users for the next event. + non_member_event = event_injection.inject_event( + self.hs, + room_id=room, + sender=self.u_bob, + prev_event_ids=[bob_event.event_id], + type="m.test.2", + state_key=self.u_bob, + content={}, + ) + event, context = event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[non_member_event.event_id], + type="m.test.3", + content={}, + ) + users = self.get_success( + self.store.get_joined_users_from_context(event, context) + ) + self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) + class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): @@ -145,8 +170,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def test_can_rerun_update(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now let's create a room, which will insert a membership user = UserID("alice", "test") @@ -155,7 +184,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", @@ -166,8 +195,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 5c2cf3c2db..0b88308ff4 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -34,6 +34,8 @@ class StateStoreTestCase(tests.unittest.TestCase): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -43,7 +45,10 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") yield self.store.store_room( - self.room.to_string(), room_creator_user_id="@creator:text", is_public=True + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks @@ -63,7 +68,7 @@ class StateStoreTestCase(tests.unittest.TestCase): builder ) - yield self.store.persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) return event @@ -82,7 +87,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups_ids( + state_group_map = yield self.storage.state.get_state_groups_ids( self.room, [e2.event_id] ) self.assertEqual(len(state_group_map), 1) @@ -101,7 +106,9 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) + state_group_map = yield self.storage.state.get_state_groups( + self.room, [e2.event_id] + ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -141,7 +148,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.store.get_state_for_event(e5.event_id) + state = yield self.storage.state.get_state_for_event(e5.event_id) self.assertIsNotNone(e4) @@ -157,21 +164,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -181,7 +188,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: {self.u_alice.to_string()}}, @@ -199,7 +206,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -215,13 +222,18 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) + group_ids = yield self.storage.state.get_state_groups_ids( + room_id, [e5.event_id] + ) group = list(group_ids.keys())[0] # test _get_state_for_group_using_cache correctly filters out members # with types=[] - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -237,8 +249,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -250,8 +265,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -267,8 +285,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -287,8 +308,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -304,8 +328,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -317,8 +344,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -331,9 +361,11 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### # deliberately remove e2 (room name) from the _state_group_cache - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, True) self.assertEqual(known_absent, set()) @@ -346,21 +378,23 @@ class StateStoreTestCase(tests.unittest.TestCase): ) state_dict_ids.pop((e2.type, e2.state_key)) - self.store._state_group_cache.invalidate(group) - self.store._state_group_cache.update( - sequence=self.store._state_group_cache.sequence, + self.state_datastore._state_group_cache.invalidate(group) + self.state_datastore._state_group_cache.update( + sequence=self.state_datastore._state_group_cache.sequence, key=group, value=state_dict_ids, # list fetched keys so it knows it's partial fetched_keys=((e1.type, e1.state_key),), ) - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, False) - self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) + self.assertEqual(known_absent, {(e1.type, e1.state_key)}) self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) ############################################ @@ -369,8 +403,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -381,8 +418,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -394,8 +434,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -405,8 +448,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -424,8 +470,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -435,8 +484,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -448,8 +500,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -459,8 +514,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index a771d5af29..8e817e2c7f 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.util.retryutils import MAX_RETRY_INTERVAL + from tests.unittest import HomeserverTestCase @@ -45,3 +47,12 @@ class TransactionStoreTestCase(HomeserverTestCase): """ d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) self.get_success(d) + + def test_large_destination_retry(self): + d = self.store.set_destination_retry_timings( + "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL + ) + self.get_success(d) + + d = self.store.get_destination_retry_timings("example.com") + self.get_success(d) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index d7d244ce97..6a545d2eb0 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -15,8 +15,6 @@ from twisted.internet import defer -from synapse.storage import UserDirectoryStore - from tests import unittest from tests.utils import setup_test_homeserver @@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) + self.store = self.hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. |