summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test__base.py22
-rw-r--r--tests/storage/test_appservice.py80
-rw-r--r--tests/storage/test_background_update.py76
-rw-r--r--tests/storage/test_base.py33
-rw-r--r--tests/storage/test_cleanup_extrems.py326
-rw-r--r--tests/storage/test_client_ips.py208
-rw-r--r--tests/storage/test_devices.py32
-rw-r--r--tests/storage/test_e2e_room_keys.py75
-rw-r--r--tests/storage/test_end_to_end_keys.py12
-rw-r--r--tests/storage/test_event_federation.py210
-rw-r--r--tests/storage/test_event_metrics.py80
-rw-r--r--tests/storage/test_event_push_actions.py12
-rw-r--r--tests/storage/test_keys.py15
-rw-r--r--tests/storage/test_monthly_active_users.py149
-rw-r--r--tests/storage/test_profile.py12
-rw-r--r--tests/storage/test_purge.py15
-rw-r--r--tests/storage/test_redaction.py296
-rw-r--r--tests/storage/test_registration.py41
-rw-r--r--tests/storage/test_room.py14
-rw-r--r--tests/storage/test_roommember.py174
-rw-r--r--tests/storage/test_state.py178
-rw-r--r--tests/storage/test_transactions.py19
-rw-r--r--tests/storage/test_user_directory.py4
23 files changed, 1590 insertions, 493 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py

index dd49a14524..e37260a820 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py
@@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a.func.prefill(("foo",), ObservableDeferred(d)) - self.assertEquals(a.func("foo"), d.result) + self.assertEquals(a.func("foo").result, d.result) self.assertEquals(callcount[0], 0) @defer.inlineCallbacks @@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.table_name = "table_" + hs.get_secrets().token_hex(6) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "create", lambda x, *a: x.execute(*a), "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)" @@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "index", lambda x, *a: x.execute(*a), "CREATE UNIQUE INDEX %sindex ON %s(id, username)" @@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["hello"], ["there"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,13 +367,13 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, "user2", "there")]), + {(1, "user1", "hello"), (2, "user2", "there")}, ) # Update only user2 @@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["bleb"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,11 +394,11 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, "user2", "bleb")]), + {(1, "user1", "hello"), (2, "user2", "bleb")}, ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 25a6c89ef5..31710949a8 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py
@@ -24,10 +24,11 @@ from twisted.internet import defer from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.storage.appservice import ( +from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -54,7 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + self.store = ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -65,16 +69,16 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): pass def _add_appservice(self, as_token, id, url, hs_token, sender): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token=hs_token, - id=id, - sender_localpart=sender, - namespaces={}, - ) + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": hs_token, + "id": id, + "sender_localpart": sender, + "namespaces": {}, + } # use the token as the filename - with open(as_token, 'w') as outfile: + with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) @@ -109,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - self.db_pool = hs.get_db_pool() - self.engine = hs.database_engine - self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"}, @@ -123,19 +124,27 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - self.store = TestTransactionStore(hs.get_db_conn(), hs) + # We assume there is only one database in these tests + database = hs.get_datastores().databases[0] + self.db_pool = database._db_pool + self.engine = database.engine - def _add_service(self, url, as_token, id): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token="something", - id=id, - sender_localpart="a_sender", - namespaces={}, + db_config = hs.config.get_single_database() + self.store = TestTransactionStore( + database, make_conn(db_config, self.engine), hs ) + + def _add_service(self, url, as_token, id): + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": "something", + "id": id, + "sender_localpart": "a_sender", + "namespaces": {}, + } # use the token as the filename - with open(as_token, 'w') as outfile: + with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) @@ -375,15 +384,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) self.assertEquals(2, len(services)) self.assertEquals( - set([self.as_list[2]["id"], self.as_list[0]["id"]]), - set([services[0].id, services[1].id]), + {self.as_list[2]["id"], self.as_list[0]["id"]}, + {services[0].id, services[1].id}, ) # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, db_conn, hs): - super(TestTransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TestTransactionStore, self).__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): @@ -416,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -432,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) @@ -453,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index fbb9302694..ae14fb407d 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py
@@ -2,74 +2,84 @@ from mock import Mock from twisted.internet import defer +from synapse.storage.background_updates import BackgroundUpdater + from tests import unittest -from tests.utils import setup_test_homeserver -class BackgroundUpdateTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) - self.store = hs.get_datastore() - self.clock = hs.get_clock() +class BackgroundUpdateTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater + # the base test class should have run the real bg updates for us + self.assertTrue(self.updates.has_completed_background_updates()) self.update_handler = Mock() - - yield self.store.register_background_update_handler( + self.updates.register_background_update_handler( "test_update", self.update_handler ) - # run the real background updates, to get them out the way - # (perhaps we should run them as part of the test HS setup, since we - # run all of the other schema setup stuff there?) - while True: - res = yield self.store.do_next_background_update(1000) - if res is None: - break - - @defer.inlineCallbacks def test_do_background_update(self): - desired_count = 1000 + # the time we claim each update takes duration_ms = 42 + # the target runtime for each bg update + target_background_update_duration_ms = 50000 + # first step: make a bit of progress @defer.inlineCallbacks def update(progress, count): - self.clock.advance_time_msec(count * duration_ms) + yield self.clock.sleep((count * duration_ms) / 1000) progress = {"my_key": progress["my_key"] + 1} - yield self.store.runInteraction( + yield self.hs.get_datastore().db.runInteraction( "update_progress", - self.store._background_update_progress_txn, + self.updates._background_update_progress_txn, "test_update", progress, ) - defer.returnValue(count) + return count self.update_handler.side_effect = update - yield self.store.start_background_update("test_update", {"my_key": 1}) - + self.get_success( + self.updates.start_background_update("test_update", {"my_key": 1}) + ) self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) - self.assertIsNotNone(result) + res = self.get_success( + self.updates.do_next_background_update( + target_background_update_duration_ms + ), + by=0.1, + ) + self.assertIsNotNone(res) + + # on the first call, we should get run with the default background update size self.update_handler.assert_called_once_with( - {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update + # we should now get run with a much bigger number of items to update @defer.inlineCallbacks def update(progress, count): - yield self.store._end_background_update("test_update") - defer.returnValue(count) + self.assertEqual(progress, {"my_key": 2}) + self.assertAlmostEqual( + count, target_background_update_duration_ms / duration_ms, places=0, + ) + yield self.updates._end_background_update("test_update") + return count self.update_handler.side_effect = update self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) + ) self.assertIsNotNone(result) - self.update_handler.assert_called_once_with({"my_key": 2}, desired_count) + self.update_handler.assert_called_once() # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = self.get_success( + self.updates.do_next_background_update(target_background_update_duration_ms) + ) self.assertIsNone(result) self.assertFalse(self.update_handler.called) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index c778de1f0c..cdee0a9e60 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py
@@ -21,6 +21,7 @@ from mock import Mock from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from tests import unittest @@ -51,21 +52,23 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True config.event_cache_size = 1 - config.database_config = {"name": "sqlite3"} - engine = create_engine(config.database_config) + hs = TestHomeServer("test", config=config) + + sqlite_config = {"name": "sqlite3"} + engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - hs = TestHomeServer( - "test", db_pool=self.db_pool, config=config, database_engine=fake_engine - ) - self.datastore = SQLBaseStore(None, hs) + db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db._db_pool = self.db_pool + + self.datastore = SQLBaseStore(db, None, hs) @defer.inlineCallbacks def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.db.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -77,7 +80,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.db.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -92,7 +95,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore._simple_select_one_onecol( + value = yield self.datastore.db.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -106,7 +109,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -122,7 +125,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -137,7 +140,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore._simple_select_list( + ret = yield self.datastore.db.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -150,7 +153,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -165,7 +168,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -180,7 +183,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_delete_one( + yield self.datastore.db.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 6aa8b8b3c6..0e04b2cf92 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -14,8 +14,13 @@ # limitations under the License. import os.path +from unittest.mock import patch +from mock import Mock + +import synapse.rest.admin from synapse.api.constants import EventTypes +from synapse.rest.client.v1 import login, room from synapse.storage import prepare_database from synapse.types import Requester, UserID @@ -23,17 +28,12 @@ from tests.unittest import HomeserverTestCase class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): - """Test the background update to clean forward extremities table. """ - def make_homeserver(self, reactor, clock): - # Hack until we understand why test_forked_graph_cleanup fails with v4 - config = self.default_config() - config['default_room_version'] = '1' - return self.setup_test_homeserver(config=config) + Test the background update to clean forward extremities table. + """ def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() - self.event_creator = homeserver.get_event_creation_handler() self.room_creator = homeserver.get_room_creation_handler() # Create a test user and room @@ -42,64 +42,18 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): info = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] - def create_and_send_event(self, soft_failed=False, prev_event_ids=None): - """Create and send an event. - - Args: - soft_failed (bool): Whether to create a soft failed event or not - prev_event_ids (list[str]|None): Explicitly set the prev events, - or if None just use the default - - Returns: - str: The new event's ID. - """ - prev_events_and_hashes = None - if prev_event_ids: - prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] - - event, context = self.get_success( - self.event_creator.create_event( - self.requester, - { - "type": EventTypes.Message, - "room_id": self.room_id, - "sender": self.user.to_string(), - "content": {"body": "", "msgtype": "m.text"}, - }, - prev_events_and_hashes=prev_events_and_hashes, - ) - ) - - if soft_failed: - event.internal_metadata.soft_failed = True - - self.get_success( - self.event_creator.send_nonmember_event(self.requester, event, context) - ) - - return event.event_id - - def add_extremity(self, event_id): - """Add the given event as an extremity to the room. - """ - self.get_success( - self.store._simple_insert( - table="event_forward_extremities", - values={"room_id": self.room_id, "event_id": event_id}, - desc="test_add_extremity", - ) - ) - - self.store.get_latest_event_ids_in_room.invalidate((self.room_id,)) - def run_background_update(self): """Re run the background update to clean up the extremities. """ # Make sure we don't clash with in progress updates. - self.assertTrue(self.store._all_done, "Background updates are still ongoing") + self.assertTrue( + self.store.db.updates._all_done, "Background updates are still ongoing" + ) schema_path = os.path.join( prepare_database.dir_path, + "data_stores", + "main", "schema", "delta", "54", @@ -110,14 +64,20 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): prepare_database.executescript(txn, schema_path) self.get_success( - self.store.runInteraction("test_delete_forward_extremities", run_delta_file) + self.store.db.runInteraction( + "test_delete_forward_extremities", run_delta_file + ) ) # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of @@ -131,10 +91,16 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ # Create the room graph - event_id_1 = self.create_and_send_event() - event_id_2 = self.create_and_send_event(True, [event_id_1]) - event_id_3 = self.create_and_send_event(True, [event_id_2]) - event_id_4 = self.create_and_send_event(False, [event_id_3]) + event_id_1 = self.create_and_send_event(self.room_id, self.user) + event_id_2 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_1] + ) + event_id_3 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_2] + ) + event_id_4 = self.create_and_send_event( + self.room_id, self.user, False, [event_id_3] + ) # Check the latest events are as expected latest_event_ids = self.get_success( @@ -154,17 +120,21 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): Where SF* are soft failed, and with extremities of A and B """ # Create the room graph - event_id_a = self.create_and_send_event() - event_id_sf1 = self.create_and_send_event(True, [event_id_a]) - event_id_b = self.create_and_send_event(False, [event_id_sf1]) + event_id_a = self.create_and_send_event(self.room_id, self.user) + event_id_sf1 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a] + ) + event_id_b = self.create_and_send_event( + self.room_id, self.user, False, [event_id_sf1] + ) # Add the new extremity and check the latest events are as expected - self.add_extremity(event_id_a) + self.add_extremity(self.room_id, event_id_a) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -185,18 +155,24 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): Where SF* are soft failed, and with extremities of A and B """ # Create the room graph - event_id_a = self.create_and_send_event() - event_id_sf1 = self.create_and_send_event(True, [event_id_a]) - event_id_sf2 = self.create_and_send_event(True, [event_id_sf1]) - event_id_b = self.create_and_send_event(False, [event_id_sf2]) + event_id_a = self.create_and_send_event(self.room_id, self.user) + event_id_sf1 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a] + ) + event_id_sf2 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_sf1] + ) + event_id_b = self.create_and_send_event( + self.room_id, self.user, False, [event_id_sf2] + ) # Add the new extremity and check the latest events are as expected - self.add_extremity(event_id_a) + self.add_extremity(self.room_id, event_id_a) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -227,23 +203,31 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """ # Create the room graph - event_id_a = self.create_and_send_event() - event_id_b = self.create_and_send_event() - event_id_sf1 = self.create_and_send_event(True, [event_id_a]) - event_id_sf2 = self.create_and_send_event(True, [event_id_a, event_id_b]) - event_id_sf3 = self.create_and_send_event(True, [event_id_sf1]) - self.create_and_send_event(True, [event_id_sf2, event_id_sf3]) # SF4 - event_id_c = self.create_and_send_event(False, [event_id_sf3]) + event_id_a = self.create_and_send_event(self.room_id, self.user) + event_id_b = self.create_and_send_event(self.room_id, self.user) + event_id_sf1 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a] + ) + event_id_sf2 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_a, event_id_b] + ) + event_id_sf3 = self.create_and_send_event( + self.room_id, self.user, True, [event_id_sf1] + ) + self.create_and_send_event( + self.room_id, self.user, True, [event_id_sf2, event_id_sf3] + ) # SF4 + event_id_c = self.create_and_send_event( + self.room_id, self.user, False, [event_id_sf3] + ) # Add the new extremity and check the latest events are as expected - self.add_extremity(event_id_a) + self.add_extremity(self.room_id, event_id_a) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual( - set(latest_event_ids), set((event_id_a, event_id_b, event_id_c)) - ) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c}) # Run the background update and check it did the right thing self.run_background_update() @@ -251,4 +235,168 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c])) + self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c}) + + +class CleanupExtremDummyEventsTestCase(HomeserverTestCase): + CONSENT_VERSION = "1" + EXTREMITIES_COUNT = 50 + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["cleanup_extremities_with_dummy_events"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() + self.room_creator = homeserver.get_room_creation_handler() + self.event_creator_handler = homeserver.get_event_creation_handler() + + # Create a test user and room + self.user = UserID.from_string(self.register_user("user1", "password")) + self.token1 = self.login("user1", "password") + self.requester = Requester(self.user, None, False, None, None) + info = self.get_success(self.room_creator.create_room(self.requester, {})) + self.room_id = info["room_id"] + self.event_creator = homeserver.get_event_creation_handler() + homeserver.config.user_consent_version = self.CONSENT_VERSION + + def test_send_dummy_event(self): + self._create_extremity_rich_graph() + + # Pump the reactor repeatedly so that the background updates have a + # chance to run. + self.pump(10 * 60) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) + def test_send_dummy_events_when_insufficient_power(self): + self._create_extremity_rich_graph() + # Criple power levels + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + body={"users": {str(self.user): -1}}, + tok=self.token1, + ) + # Pump the reactor repeatedly so that the background updates have a + # chance to run. + self.pump(10 * 60) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + # Check that the room has not been pruned + self.assertTrue(len(latest_event_ids) > 10) + + # New user with regular levels + user2 = self.register_user("user2", "password") + token2 = self.login("user2", "password") + self.helper.join(self.room_id, user2, tok=token2) + self.pump(10 * 60) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0) + def test_send_dummy_event_without_consent(self): + self._create_extremity_rich_graph() + self._enable_consent_checking() + + # Pump the reactor repeatedly so that the background updates have a + # chance to run. Attempt to add dummy event with user that has not consented + # Check that dummy event send fails. + self.pump(10 * 60) + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT) + + # Create new user, and add consent + user2 = self.register_user("user2", "password") + token2 = self.login("user2", "password") + self.get_success( + self.store.user_set_consent_version(user2, self.CONSENT_VERSION) + ) + self.helper.join(self.room_id, user2, tok=token2) + + # Background updates should now cause a dummy event to be added to the graph + self.pump(10 * 60) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids)) + + @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250) + def test_expiry_logic(self): + """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion() + expires old entries correctly. + """ + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "1" + ] = 100000 + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "2" + ] = 200000 + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ + "3" + ] = 300000 + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() + # All entries within time frame + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 3, + ) + # Oldest room to expire + self.pump(1) + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 2, + ) + # All rooms to expire + self.pump(2) + self.assertEqual( + len( + self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion + ), + 0, + ) + + def _create_extremity_rich_graph(self): + """Helper method to create bushy graph on demand""" + + event_id_start = self.create_and_send_event(self.room_id, self.user) + + for _ in range(self.EXTREMITIES_COUNT): + self.create_and_send_event( + self.room_id, self.user, prev_event_ids=[event_id_start] + ) + + latest_event_ids = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) + ) + self.assertEqual(len(latest_event_ids), 50) + + def _enable_consent_checking(self): + """Helper method to enable consent checking""" + self.event_creator._block_events_without_consent_error = "No consent from user" + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creator._consent_uri_builder = consent_uri_builder diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index b62eae7abc..bf674dd184 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(12345678) user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -47,15 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", - "access_token": "access_token", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, @@ -82,7 +85,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -94,11 +97,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): result, [ { - 'access_token': 'access_token', - 'ip': 'ip', - 'user_agent': 'user_agent', - 'device_id': None, - 'last_seen': 12345678000, + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": None, + "last_seen": 12345678000, } ], ) @@ -113,7 +116,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -125,11 +128,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): result, [ { - 'access_token': 'access_token', - 'ip': 'ip', - 'user_agent': 'user_agent', - 'device_id': None, - 'last_seen': 12345878000, + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": None, + "last_seen": 12345878000, } ], ) @@ -185,9 +188,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" - self.get_success( - self.store.register(user_id=user_id, token="123", password_hash=None) - ) + self.get_success(self.store.register_user(user_id=user_id, password_hash=None)) active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) @@ -203,6 +204,173 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + def test_devices_last_seen_bg_update(self): + # First make sure we have completed all updates. + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) + + user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", device_id + ) + ) + # Force persisting to disk + self.reactor.advance(200) + + # But clear the associated entry in devices table + self.get_success( + self.store.db.simple_update( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues={"last_seen": None, "ip": None, "user_agent": None}, + desc="test_devices_last_seen_bg_update", + ) + ) + + # We should now get nulls when querying + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, device_id) + ) + + r = result[(user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": device_id, + "ip": None, + "user_agent": None, + "last_seen": None, + }, + r, + ) + + # Register the background update to run again. + self.get_success( + self.store.db.simple_insert( + table="background_updates", + values={ + "update_name": "devices_last_seen", + "progress_json": "{}", + "depends_on": None, + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db.updates._all_done = False + + # Now let's actually drive the updates to completion + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) + + # We should now get the correct result again + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, device_id) + ) + + r = result[(user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": device_id, + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 0, + }, + r, + ) + + def test_old_user_ips_pruned(self): + # First make sure we have completed all updates. + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) + + user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", device_id + ) + ) + + # Force persisting to disk + self.reactor.advance(200) + + # We should see that in the DB + result = self.get_success( + self.store.db.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + + self.assertEqual( + result, + [ + { + "access_token": "access_token", + "ip": "ip", + "user_agent": "user_agent", + "device_id": device_id, + "last_seen": 0, + } + ], + ) + + # Now advance by a couple of months + self.reactor.advance(60 * 24 * 60 * 60) + + # We should get no results. + result = self.get_success( + self.store.db.simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + + self.assertEqual(result, []) + + # But we should still get the correct values for the device + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, device_id) + ) + + r = result[(user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": user_id, + "device_id": device_id, + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 0, + }, + r, + ) + class ClientIpAuthTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 6396ccddb5..6f8d990959 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py
@@ -72,42 +72,42 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) @defer.inlineCallbacks - def test_get_devices_by_remote(self): + def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id yield self.store.add_device_change_to_streams( - "user_id", device_ids, ["somehost"], + "user_id", device_ids, ["somehost"] ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "somehost", -1, limit=100, + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( + "somehost", -1, limit=100 ) # Check original device_ids are contained within these updates self._check_devices_in_updates(device_ids, device_updates) @defer.inlineCallbacks - def test_get_devices_by_remote_limited(self): + def test_get_device_updates_by_remote_limited(self): # Test breaking the update limit in 1, 101, and 1 device_id segments # first add one device device_ids1 = ["device_id0"] yield self.store.add_device_change_to_streams( - "user_id", device_ids1, ["someotherhost"], + "user_id", device_ids1, ["someotherhost"] ) # then add 101 device_ids2 = ["device_id" + str(i + 1) for i in range(101)] yield self.store.add_device_change_to_streams( - "user_id", device_ids2, ["someotherhost"], + "user_id", device_ids2, ["someotherhost"] ) # then one more device_ids3 = ["newdevice"] yield self.store.add_device_change_to_streams( - "user_id", device_ids3, ["someotherhost"], + "user_id", device_ids3, ["someotherhost"] ) # @@ -115,21 +115,21 @@ class DeviceStoreTestCase(tests.unittest.TestCase): # # first we should get a single update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", -1, limit=100, + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( + "someotherhost", -1, limit=100 ) self._check_devices_in_updates(device_ids1, device_updates) # Then we should get an empty list back as the 101 devices broke the limit - now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100, + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( + "someotherhost", now_stream_id, limit=100 ) self.assertEqual(len(device_updates), 0) # The 101 devices should've been cleared, so we should now just get one device # update - now_stream_id, device_updates = yield self.store.get_devices_by_remote( - "someotherhost", now_stream_id, limit=100, + now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( + "someotherhost", now_stream_id, limit=100 ) self._check_devices_in_updates(device_ids3, device_updates) @@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase): """Check that an specific device ids exist in a list of device update EDUs""" self.assertEqual(len(device_updates), len(expected_device_ids)) - received_device_ids = {update["device_id"] for update in device_updates} + received_device_ids = { + update["device_id"] for edu_type, update in device_updates + } self.assertEqual(received_device_ids, set(expected_device_ids)) @defer.inlineCallbacks diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py new file mode 100644
index 0000000000..35dafbb904 --- /dev/null +++ b/tests/storage/test_e2e_room_keys.py
@@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests import unittest + +# sample room_key data for use in the tests +room_key = { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": False, + "session_data": "SSBBTSBBIEZJU0gK", +} + + +class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) + self.store = hs.get_datastore() + return hs + + def test_room_keys_version_delete(self): + # test that deleting a room key backup deletes the keys + version1 = self.get_success( + self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + ) + + self.get_success( + self.store.add_e2e_room_keys( + "user_id", version1, [("room", "session", room_key)] + ) + ) + + version2 = self.get_success( + self.store.create_e2e_room_keys_version( + "user_id", {"algorithm": "rot13", "auth_data": {}} + ) + ) + + self.get_success( + self.store.add_e2e_room_keys( + "user_id", version2, [("room", "session", room_key)] + ) + ) + + # make sure the keys were stored properly + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1)) + self.assertEqual(len(keys["rooms"]), 1) + + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2)) + self.assertEqual(len(keys["rooms"]), 1) + + # delete version1 + self.get_success(self.store.delete_e2e_room_keys_version("user_id", version1)) + + # make sure the key from version1 is gone, and the key from version2 is + # still there + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1)) + self.assertEqual(len(keys["rooms"]), 0) + + keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2)) + self.assertEqual(len(keys["rooms"]), 1) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index cd2bcd4ca3..398d546280 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py
@@ -38,7 +38,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev) + self.assertDictContainsSubset(json, dev) @defer.inlineCallbacks def test_reupload_key(self): @@ -68,7 +68,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): self.assertIn("device", res["user"]) dev = res["user"]["device"] self.assertDictContainsSubset( - {"keys": json, "device_display_name": "display_name"}, dev + {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev ) @defer.inlineCallbacks @@ -80,10 +80,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.store_device("user2", "device1", None) yield self.store.store_device("user2", "device2", None) - yield self.store.set_e2e_device_keys("user1", "device1", now, 'json11') - yield self.store.set_e2e_device_keys("user1", "device2", now, 'json12') - yield self.store.set_e2e_device_keys("user2", "device1", now, 'json21') - yield self.store.set_e2e_device_keys("user2", "device2", now, 'json22') + yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) + yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) + yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) + yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) res = yield self.store.get_e2e_device_keys( (("user1", "device1"), ("user2", "device2")) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 0d4e74d637..3aeec0dc0f 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py
@@ -13,25 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import tests.unittest import tests.utils -class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks def test_get_prev_events_for_room(self): - room_id = '@ROOM:local' + room_id = "@ROOM:local" # add a bunch of events and hashes to act as forward extremities def insert_event(txn, i): - event_id = '$event_%i:local' % i + event_id = "$event_%i:local" % i txn.execute( ( @@ -45,33 +40,194 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): txn.execute( ( - 'INSERT INTO event_forward_extremities (room_id, event_id) ' - 'VALUES (?, ?)' + "INSERT INTO event_forward_extremities (room_id, event_id) " + "VALUES (?, ?)" ), (room_id, event_id), ) txn.execute( ( - 'INSERT INTO event_reference_hashes ' - '(event_id, algorithm, hash) ' + "INSERT INTO event_reference_hashes " + "(event_id, algorithm, hash) " "VALUES (?, 'sha256', ?)" ), - (event_id, b'ffff'), + (event_id, bytearray(b"ffff")), ) - for i in range(0, 11): - yield self.store.runInteraction("insert", insert_event, i) + for i in range(0, 20): + self.get_success(self.store.db.runInteraction("insert", insert_event, i)) - # this should get the last five and five others - r = yield self.store.get_prev_events_for_room(room_id) + # this should get the last ten + r = self.get_success(self.store.get_prev_events_for_room(room_id)) self.assertEqual(10, len(r)) - for i in range(0, 5): - el = r[i] - depth = el[2] - self.assertEqual(10 - i, depth) - - for i in range(5, 5): - el = r[i] - depth = el[2] - self.assertLessEqual(5, depth) + for i in range(0, 10): + self.assertEqual("$event_%i:local" % (19 - i), r[i]) + + def test_get_rooms_with_many_extremities(self): + room1 = "#room1" + room2 = "#room2" + room3 = "#room3" + + def insert_event(txn, i, room_id): + event_id = "$event_%i:local" % i + txn.execute( + ( + "INSERT INTO event_forward_extremities (room_id, event_id) " + "VALUES (?, ?)" + ), + (room_id, event_id), + ) + + for i in range(0, 20): + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room1) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room2) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room3) + ) + + # Test simple case + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [])) + self.assertEqual(len(r), 3) + + # Does filter work? + + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1])) + self.assertTrue(room2 in r) + self.assertTrue(room3 in r) + self.assertEqual(len(r), 2) + + r = self.get_success( + self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + ) + self.assertEqual(r, [room3]) + + # Does filter and limit work? + + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) + self.assertTrue(r == [room2] or r == [room3]) + + def test_auth_difference(self): + room_id = "@ROOM:local" + + # The silly auth graph we use to test the auth difference algorithm, + # where the top are the most recent events. + # + # A B + # \ / + # D E + # \ | + # ` F C + # | /| + # G ´ | + # | \ | + # H I + # | | + # K J + + auth_graph = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], + } + + depth_map = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, + } + + # We rudely fiddle with the appropriate tables directly, as that's much + # easier than constructing events properly. + + def insert_event(txn, event_id, stream_ordering): + + depth = depth_map[event_id] + + self.store.db.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + "stream_ordering": stream_ordering, + }, + ) + + self.store.db.simple_insert_many_txn( + txn, + table="event_auth", + values=[ + {"event_id": event_id, "room_id": room_id, "auth_id": a} + for a in auth_graph[event_id] + ], + ) + + next_stream_ordering = 0 + for event_id in auth_graph: + next_stream_ordering += 1 + self.get_success( + self.store.db.runInteraction( + "insert", insert_event, event_id, next_stream_ordering + ) + ) + + # Now actually test that various combinations give the right result: + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "d", "e"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) + self.assertSetEqual(difference, set()) diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py new file mode 100644
index 0000000000..a7b7fd36d3 --- /dev/null +++ b/tests/storage/test_event_metrics.py
@@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.metrics import REGISTRY, generate_latest +from synapse.types import Requester, UserID + +from tests.unittest import HomeserverTestCase + + +class ExtremStatisticsTestCase(HomeserverTestCase): + def test_exposed_to_prometheus(self): + """ + Forward extremity counts are exposed via Prometheus. + """ + room_creator = self.hs.get_room_creation_handler() + + user = UserID("alice", "test") + requester = Requester(user, None, False, None, None) + + # Real events, forward extremities + events = [(3, 2), (6, 2), (4, 6)] + + for event_count, extrems in events: + info = self.get_success(room_creator.create_room(requester, {})) + room_id = info["room_id"] + + last_event = None + + # Make a real event chain + for i in range(event_count): + ev = self.create_and_send_event(room_id, user, False, last_event) + last_event = [ev] + + # Sprinkle in some extremities + for i in range(extrems): + ev = self.create_and_send_event(room_id, user, False, last_event) + + # Let it run for a while, then pull out the statistics from the + # Prometheus client registry + self.reactor.advance(60 * 60 * 1000) + self.pump(1) + + items = set( + filter( + lambda x: b"synapse_forward_extremities_" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + ) + + expected = { + b'synapse_forward_extremities_bucket{le="1.0"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0"} 3.0', + b'synapse_forward_extremities_bucket{le="50.0"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0"} 3.0', + b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', + b"synapse_forward_extremities_count 3.0", + b"synapse_forward_extremities_sum 10.0", + } + + self.assertEqual(items, expected) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b114c6fb1d..d4bcf1821e 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py
@@ -55,7 +55,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.runInteraction( + counts = yield self.store.db.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( @@ -74,7 +74,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield self.store.add_push_actions_to_staging( event.event_id, {user_id: action} ) - yield self.store.runInteraction( + yield self.store.db.runInteraction( "", self.store._set_push_actions_for_event_and_users_txn, [(event, None)], @@ -82,12 +82,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) def _rotate(stream): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) def _mark_read(stream, depth): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._remove_old_push_actions_before_txn, room_id, @@ -116,7 +116,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store._simple_delete( + yield self.store.db.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -135,7 +135,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store._simple_insert( + return self.store.db.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index e07ff01201..95f309fbbc 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py
@@ -14,6 +14,7 @@ # limitations under the License. import signedjson.key +import unpaddedbase64 from twisted.internet.defer import Deferred @@ -21,11 +22,17 @@ from synapse.storage.keys import FetchKeyResult import tests.unittest -KEY_1 = signedjson.key.decode_verify_key_base64( - "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" + +def decode_verify_key_base64(key_id: str, key_base64: str): + key_bytes = unpaddedbase64.decode_base64(key_base64) + return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) + + +KEY_1 = decode_verify_key_base64( + "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" ) -KEY_2 = signedjson.key.decode_verify_key_base64( - "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" +KEY_2 = decode_verify_key_base64( + "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" ) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index f458c03054..bc53bf0951 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py
@@ -46,17 +46,18 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): user3_email = "user3@matrix.org" threepids = [ - {'medium': 'email', 'address': user1_email}, - {'medium': 'email', 'address': user2_email}, - {'medium': 'email', 'address': user3_email}, + {"medium": "email", "address": user1_email}, + {"medium": "email", "address": user2_email}, + {"medium": "email", "address": user3_email}, ] + self.hs.config.mau_limits_reserved_threepids = threepids # -1 because user3 is a support user and does not count user_num = len(threepids) - 1 - self.store.register(user_id=user1, token="123", password_hash=None) - self.store.register(user_id=user2, token="456", password_hash=None) - self.store.register( - user_id=user3, token="789", password_hash=None, user_type=UserTypes.SUPPORT + self.store.register_user(user_id=user1, password_hash=None) + self.store.register_user(user_id=user2, password_hash=None) + self.store.register_user( + user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT ) self.pump() @@ -64,7 +65,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.pump() @@ -84,6 +85,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.hs.config.max_mau_value = 0 self.reactor.advance(FORTY_DAYS) + self.hs.config.max_mau_value = 5 self.store.reap_monthly_active_users() self.pump() @@ -147,9 +149,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.reap_monthly_active_users() self.pump() count = self.store.get_monthly_active_count() - self.assertEquals( - self.get_success(count), initial_users - self.hs.config.max_mau_value - ) + self.assertEquals(self.get_success(count), self.hs.config.max_mau_value) self.reactor.advance(FORTY_DAYS) self.store.reap_monthly_active_users() @@ -158,12 +158,48 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.store.get_monthly_active_count() self.assertEquals(self.get_success(count), 0) + def test_reap_monthly_active_users_reserved_users(self): + """ Tests that reaping correctly handles reaping where reserved users are + present""" + + self.hs.config.max_mau_value = 5 + initial_users = 5 + reserved_user_number = initial_users - 1 + threepids = [] + for i in range(initial_users): + user = "@user%d:server" % i + email = "user%d@example.com" % i + self.get_success(self.store.upsert_monthly_active_user(user)) + threepids.append({"medium": "email", "address": email}) + # Need to ensure that the most recent entries in the + # monthly_active_users table are reserved + now = int(self.hs.get_clock().time_msec()) + if i != 0: + self.get_success( + self.store.register_user(user_id=user, password_hash=None) + ) + self.get_success( + self.store.user_add_threepid(user, "email", email, now, now) + ) + + self.hs.config.mau_limits_reserved_threepids = threepids + self.store.db.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) + count = self.store.get_monthly_active_count() + self.assertTrue(self.get_success(count), initial_users) + + users = self.store.get_registered_reserved_users() + self.assertEquals(len(self.get_success(users)), reserved_user_number) + + self.get_success(self.store.reap_monthly_active_users()) + count = self.store.get_monthly_active_count() + self.assertEquals(self.get_success(count), self.hs.config.max_mau_value) + def test_populate_monthly_users_is_guest(self): # Test that guest users are not added to mau list user_id = "@user_id:host" - self.store.register( - user_id=user_id, token="123", password_hash=None, make_guest=True - ) + self.store.register_user(user_id=user_id, password_hash=None, make_guest=True) self.store.upsert_monthly_active_user = Mock() self.store.populate_monthly_active_users(user_id) self.pump() @@ -177,7 +213,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) - self.store.populate_monthly_active_users('user_id') + self.store.populate_monthly_active_users("user_id") self.pump() self.store.upsert_monthly_active_user.assert_called_once() @@ -188,43 +224,45 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) - self.store.populate_monthly_active_users('user_id') + self.store.populate_monthly_active_users("user_id") self.pump() self.store.upsert_monthly_active_user.assert_not_called() def test_get_reserved_real_user_account(self): # Test no reserved users, or reserved threepids - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), 0) + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEquals(len(users), 0) # Test reserved users but no registered users - user1 = '@user1:example.com' - user2 = '@user2:example.com' - user1_email = 'user1@example.com' - user2_email = 'user2@example.com' + user1 = "@user1:example.com" + user2 = "@user2:example.com" + + user1_email = "user1@example.com" + user2_email = "user2@example.com" threepids = [ - {'medium': 'email', 'address': user1_email}, - {'medium': 'email', 'address': user2_email}, + {"medium": "email", "address": user1_email}, + {"medium": "email", "address": user2_email}, ] self.hs.config.mau_limits_reserved_threepids = threepids - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.pump() - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), 0) + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEquals(len(users), 0) # Test reserved registed users - self.store.register(user_id=user1, token="123", password_hash=None) - self.store.register(user_id=user2, token="456", password_hash=None) + self.store.register_user(user_id=user1, password_hash=None) + self.store.register_user(user_id=user2, password_hash=None) self.pump() now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - count = self.store.get_registered_reserved_users_count() - self.assertEquals(self.get_success(count), len(threepids)) + + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEquals(len(users), len(threepids)) def test_support_user_not_add_to_mau_limits(self): support_user_id = "@support:test" @@ -232,11 +270,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.pump() self.assertEqual(self.get_success(count), 0) - self.store.register( - user_id=support_user_id, - token="123", - password_hash=None, - user_type=UserTypes.SUPPORT, + self.store.register_user( + user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT ) self.store.upsert_monthly_active_user(support_user_id) @@ -268,3 +303,45 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.pump() self.store.upsert_monthly_active_user.assert_not_called() + + def test_get_monthly_active_count_by_service(self): + appservice1_user1 = "@appservice1_user1:example.com" + appservice1_user2 = "@appservice1_user2:example.com" + + appservice2_user1 = "@appservice2_user1:example.com" + native_user1 = "@native_user1:example.com" + + service1 = "service1" + service2 = "service2" + native = "native" + + self.store.register_user( + user_id=appservice1_user1, password_hash=None, appservice_id=service1 + ) + self.store.register_user( + user_id=appservice1_user2, password_hash=None, appservice_id=service1 + ) + self.store.register_user( + user_id=appservice2_user1, password_hash=None, appservice_id=service2 + ) + self.store.register_user(user_id=native_user1, password_hash=None) + self.pump() + + count = self.store.get_monthly_active_count_by_service() + self.assertEqual({}, self.get_success(count)) + + self.store.upsert_monthly_active_user(native_user1) + self.store.upsert_monthly_active_user(appservice1_user1) + self.store.upsert_monthly_active_user(appservice1_user2) + self.store.upsert_monthly_active_user(appservice2_user1) + self.pump() + + count = self.store.get_monthly_active_count() + self.assertEqual(4, self.get_success(count)) + + count = self.store.get_monthly_active_count_by_service() + result = self.get_success(count) + + self.assertEqual(2, result[service1]) + self.assertEqual(1, result[service2]) + self.assertEqual(1, result[native]) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index c125a0d797..7458a37e54 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py
@@ -16,7 +16,6 @@ from twisted.internet import defer -from synapse.storage.profile import ProfileStore from synapse.types import UserID from tests import unittest @@ -28,25 +27,22 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = ProfileStore(hs.get_db_conn(), hs) + self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") @defer.inlineCallbacks def test_displayname(self): - yield self.store.set_profile_displayname( - self.u_frank.localpart, "Frank", 1, - ) + yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1) self.assertEquals( - "Frank", - (yield self.store.get_profile_displayname(self.u_frank.localpart)) + "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) ) @defer.inlineCallbacks def test_avatar_url(self): yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here", 1, + self.u_frank.localpart, "http://my.site/here", 1 ) self.assertEquals( diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index f671599cb8..b9fafaa1a6 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py
@@ -40,23 +40,24 @@ class PurgeTests(HomeserverTestCase): third = self.helper.send(self.room_id, body="test3") last = self.helper.send(self.room_id, body="test4") - storage = self.hs.get_datastore() + store = self.hs.get_datastore() + storage = self.hs.get_storage() # Get the topological token - event = storage.get_topological_token_for_event(last["event_id"]) + event = store.get_topological_token_for_event(last["event_id"]) self.pump() event = self.successResultOf(event) # Purge everything before this topological token - purge = storage.purge_history(self.room_id, event, True) + purge = storage.purge_events.purge_history(self.room_id, event, True) self.pump() self.assertEqual(self.successResultOf(purge), None) # Try and get the events - get_first = storage.get_event(first["event_id"]) - get_second = storage.get_event(second["event_id"]) - get_third = storage.get_event(third["event_id"]) - get_last = storage.get_event(last["event_id"]) + get_first = store.get_event(first["event_id"]) + get_second = store.get_event(second["event_id"]) + get_third = store.get_event(third["event_id"]) + get_last = store.get_event(last["event_id"]) self.pump() # 1-3 should fail and last will succeed, meaning that 1-3 are deleted diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 4823d44dec..db3667dc43 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +17,8 @@ from mock import Mock +from canonicaljson import json + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership @@ -23,17 +26,20 @@ from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID from tests import unittest -from tests.utils import create_room, setup_test_homeserver +from tests.utils import create_room -class RedactionTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver( - self.addCleanup, resource_for_federation=Mock(), http_client=None +class RedactionTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["redaction_retention_period"] = "30d" + return self.setup_test_homeserver( + resource_for_federation=Mock(), http_client=None, config=config ) + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -42,11 +48,12 @@ class RedactionTestCase(unittest.TestCase): self.room1 = RoomID.from_string("!abc123:test") - yield create_room(hs, self.room1.to_string(), self.u_alice.to_string()) + self.get_success( + create_room(hs, self.room1.to_string(), self.u_alice.to_string()) + ) self.depth = 1 - @defer.inlineCallbacks def inject_room_member( self, room, user, membership, replaces_state=None, extra_content={} ): @@ -63,15 +70,14 @@ class RedactionTestCase(unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.store.persist_event(event, context) + self.get_success(self.storage.persistence.persist_event(event, context)) - defer.returnValue(event) + return event - @defer.inlineCallbacks def inject_message(self, room, user, body): self.depth += 1 @@ -82,19 +88,18 @@ class RedactionTestCase(unittest.TestCase): "sender": user.to_string(), "state_key": user.to_string(), "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, + "content": {"body": body, "msgtype": "message"}, }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.store.persist_event(event, context) + self.get_success(self.storage.persistence.persist_event(event, context)) - defer.returnValue(event) + return event - @defer.inlineCallbacks def inject_redaction(self, room, event_id, user, reason): builder = self.event_builder_factory.for_room_version( RoomVersions.V1, @@ -108,20 +113,23 @@ class RedactionTestCase(unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) ) - yield self.store.persist_event(event, context) + self.get_success(self.storage.persistence.persist_event(event, context)) + + return event - @defer.inlineCallbacks def test_redact(self): - yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) - msg_event = yield self.inject_message(self.room1, self.u_alice, u"t") + msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) # Check event has not been redacted: - event = yield self.store.get_event(msg_event.event_id) + event = self.get_success(self.store.get_event(msg_event.event_id)) self.assertObjectHasAttributes( { @@ -136,11 +144,11 @@ class RedactionTestCase(unittest.TestCase): # Redact event reason = "Because I said so" - yield self.inject_redaction( - self.room1, msg_event.event_id, self.u_alice, reason + self.get_success( + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) ) - event = yield self.store.get_event(msg_event.event_id) + event = self.get_success(self.store.get_event(msg_event.event_id)) self.assertEqual(msg_event.event_id, event.event_id) @@ -164,15 +172,18 @@ class RedactionTestCase(unittest.TestCase): event.unsigned["redacted_because"], ) - @defer.inlineCallbacks def test_redact_join(self): - yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) - msg_event = yield self.inject_room_member( - self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"} + msg_event = self.get_success( + self.inject_room_member( + self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"} + ) ) - event = yield self.store.get_event(msg_event.event_id) + event = self.get_success(self.store.get_event(msg_event.event_id)) self.assertObjectHasAttributes( { @@ -187,13 +198,13 @@ class RedactionTestCase(unittest.TestCase): # Redact event reason = "Because I said so" - yield self.inject_redaction( - self.room1, msg_event.event_id, self.u_alice, reason + self.get_success( + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) ) # Check redaction - event = yield self.store.get_event(msg_event.event_id) + event = self.get_success(self.store.get_event(msg_event.event_id)) self.assertTrue("redacted_because" in event.unsigned) @@ -214,3 +225,214 @@ class RedactionTestCase(unittest.TestCase): }, event.unsigned["redacted_because"], ) + + def test_circular_redaction(self): + redaction_event_id1 = "$redaction1_id:test" + redaction_event_id2 = "$redaction2_id:test" + + class EventIdManglingBuilder: + def __init__(self, base_builder, event_id): + self._base_builder = base_builder + self._event_id = event_id + + @defer.inlineCallbacks + def build(self, prev_event_ids): + built_event = yield self._base_builder.build(prev_event_ids) + + built_event._event_id = self._event_id + built_event._dict["event_id"] = self._event_id + assert built_event.event_id == self._event_id + + return built_event + + @property + def room_id(self): + return self._base_builder.room_id + + event_1, context_1 = self.get_success( + self.event_creation_handler.create_new_client_event( + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id2, + }, + ), + redaction_event_id1, + ) + ) + ) + + self.get_success(self.storage.persistence.persist_event(event_1, context_1)) + + event_2, context_2 = self.get_success( + self.event_creation_handler.create_new_client_event( + EventIdManglingBuilder( + self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "test"}, + "redacts": redaction_event_id1, + }, + ), + redaction_event_id2, + ) + ) + ) + self.get_success(self.storage.persistence.persist_event(event_2, context_2)) + + # fetch one of the redactions + fetched = self.get_success(self.store.get_event(redaction_event_id1)) + + # it should have been redacted + self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2) + self.assertEqual( + fetched.unsigned["redacted_because"].event_id, redaction_event_id2 + ) + + def test_redact_censor(self): + """Test that a redacted event gets censored in the DB after a month + """ + + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) + + msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + + # Check event has not been redacted: + event = self.get_success(self.store.get_event(msg_event.event_id)) + + self.assertObjectHasAttributes( + { + "type": EventTypes.Message, + "user_id": self.u_alice.to_string(), + "content": {"body": "t", "msgtype": "message"}, + }, + event, + ) + + self.assertFalse("redacted_because" in event.unsigned) + + # Redact event + reason = "Because I said so" + self.get_success( + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) + ) + + event = self.get_success(self.store.get_event(msg_event.event_id)) + + self.assertTrue("redacted_because" in event.unsigned) + + self.assertObjectHasAttributes( + { + "type": EventTypes.Message, + "user_id": self.u_alice.to_string(), + "content": {}, + }, + event, + ) + + event_json = self.get_success( + self.store.db.simple_select_one_onecol( + table="event_json", + keyvalues={"event_id": msg_event.event_id}, + retcol="json", + ) + ) + + self.assert_dict( + {"content": {"body": "t", "msgtype": "message"}}, json.loads(event_json) + ) + + # Advance by 30 days, then advance again to ensure that the looping call + # for updating the stream position gets called and then the looping call + # for the censoring gets called. + self.reactor.advance(60 * 60 * 24 * 31) + self.reactor.advance(60 * 60 * 2) + + event_json = self.get_success( + self.store.db.simple_select_one_onecol( + table="event_json", + keyvalues={"event_id": msg_event.event_id}, + retcol="json", + ) + ) + + self.assert_dict({"content": {}}, json.loads(event_json)) + + def test_redact_redaction(self): + """Tests that we can redact a redaction and can fetch it again. + """ + + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) + + msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + + first_redact_event = self.get_success( + self.inject_redaction( + self.room1, msg_event.event_id, self.u_alice, "Redacting message" + ) + ) + + self.get_success( + self.inject_redaction( + self.room1, + first_redact_event.event_id, + self.u_alice, + "Redacting redaction", + ) + ) + + # Now lets jump to the future where we have censored the redaction event + # in the DB. + self.reactor.advance(60 * 60 * 24 * 31) + + # We just want to check that fetching the event doesn't raise an exception. + self.get_success( + self.store.get_event(first_redact_event.event_id, allow_none=True) + ) + + def test_store_redacted_redaction(self): + """Tests that we can store a redacted redaction. + """ + + self.get_success( + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) + ) + + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Redaction, + "sender": self.u_alice.to_string(), + "room_id": self.room1.to_string(), + "content": {"reason": "foo"}, + }, + ) + + redaction_event, context = self.get_success( + self.event_creation_handler.create_new_client_event(builder) + ) + + self.get_success( + self.storage.persistence.persist_event(redaction_event, context) + ) + + # Now lets jump to the future where we have censored the redaction event + # in the DB. + self.reactor.advance(60 * 60 * 24 * 31) + + # We just want to check that fetching the event doesn't raise an exception. + self.get_success( + self.store.get_event(redaction_event.event_id, allow_none=True) + ) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index c0e0155bb4..71a40a0a49 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py
@@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() @@ -37,33 +36,30 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_register(self): - yield self.store.register(self.user_id, self.tokens[0], self.pwhash) + yield self.store.register_user(self.user_id, self.pwhash) self.assertEquals( { # TODO(paul): Surely this field should be 'user_id', not 'name' "name": self.user_id, "password_hash": self.pwhash, + "admin": 0, "is_guest": 0, "consent_version": None, "consent_server_notice_sent": None, "appservice_id": None, "creation_ts": 1000, + "user_type": None, + "deactivated": 0, }, (yield self.store.get_user_by_id(self.user_id)), ) - result = yield self.store.get_user_by_access_token(self.tokens[0]) - - self.assertDictContainsSubset({"name": self.user_id}, result) - - self.assertTrue("token_id" in result) - @defer.inlineCallbacks def test_add_tokens(self): - yield self.store.register(self.user_id, self.tokens[0], self.pwhash) + yield self.store.register_user(self.user_id, self.pwhash) yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None ) result = yield self.store.get_user_by_access_token(self.tokens[1]) @@ -77,9 +73,12 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_user_delete_access_tokens(self): # add some tokens - yield self.store.register(self.user_id, self.tokens[0], self.pwhash) + yield self.store.register_user(self.user_id, self.pwhash) yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id + self.user_id, self.tokens[0], device_id=None, valid_until_ms=None + ) + yield self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None ) # now delete some @@ -108,24 +107,12 @@ class RegistrationStoreTestCase(unittest.TestCase): res = yield self.store.is_support_user(None) self.assertFalse(res) - yield self.store.register(user_id=TEST_USER, token="123", password_hash=None) + yield self.store.register_user(user_id=TEST_USER, password_hash=None) res = yield self.store.is_support_user(TEST_USER) self.assertFalse(res) - yield self.store.register( - user_id=SUPPORT_USER, - token="456", - password_hash=None, - user_type=UserTypes.SUPPORT, + yield self.store.register_user( + user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT ) res = yield self.store.is_support_user(SUPPORT_USER) self.assertTrue(res) - - -class TokenGenerator: - def __init__(self): - self._last_issued_token = 0 - - def generate(self, user_id): - self._last_issued_token += 1 - return u"%s-%d" % (user_id, self._last_issued_token) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index a1ea23b068..086adeb8fd 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py
@@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes +from synapse.api.room_versions import RoomVersions from synapse.types import RoomAlias, RoomID, UserID from tests import unittest @@ -40,6 +41,7 @@ class RoomStoreTestCase(unittest.TestCase): self.room.to_string(), room_creator_user_id=self.u_creator.to_string(), is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks @@ -62,23 +64,27 @@ class RoomEventsStoreTestCase(unittest.TestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") yield self.store.store_room( - self.room.to_string(), room_creator_user_id="@creator:text", is_public=True + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.store.persist_event( + yield self.storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) @defer.inlineCallbacks def STALE_test_room_name(self): - name = u"A-Room-Name" + name = "A-Room-Name" yield self.inject_room_event( etype=EventTypes.Name, name=name, content={"name": name}, depth=1 @@ -94,7 +100,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def STALE_test_room_topic(self): - topic = u"A place for things" + topic = "A place for things" yield self.inject_room_event( etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 73ed943f5a..00df0ea68e 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,74 +14,145 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import Mock -from mock import Mock +from synapse.api.constants import Membership +from synapse.rest.admin import register_servlets_for_client_rest_resource +from synapse.rest.client.v1 import login, room +from synapse.types import Requester, UserID -from twisted.internet import defer +from tests import unittest -from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import RoomVersions -from synapse.types import RoomID, UserID -from tests import unittest -from tests.utils import create_room, setup_test_homeserver +class RoomMemberStoreTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + register_servlets_for_client_rest_resource, + room.register_servlets, + ] -class RoomMemberStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver( - self.addCleanup, resource_for_federation=Mock(), http_client=None + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver( + resource_for_federation=Mock(), http_client=None ) + return hs + + def prepare(self, reactor, clock, hs): + # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() - self.event_builder_factory = hs.get_event_builder_factory() - self.event_creation_handler = hs.get_event_creation_handler() - self.u_alice = UserID.from_string("@alice:test") - self.u_bob = UserID.from_string("@bob:test") + self.u_alice = self.register_user("alice", "pass") + self.t_alice = self.login("alice", "pass") + self.u_bob = self.register_user("bob", "pass") # User elsewhere on another host self.u_charlie = UserID.from_string("@charlie:elsewhere") - self.room = RoomID.from_string("!abc123:test") - - yield create_room(hs, self.room.to_string(), self.u_alice.to_string()) - - @defer.inlineCallbacks - def inject_room_member(self, room, user, membership, replaces_state=None): - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"membership": membership}, - }, - ) + def test_one_member(self): + + # Alice creates the room, and is automatically joined + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + rooms_for_user = self.get_success( + self.store.get_rooms_for_local_user_where_membership_is( + self.u_alice, [Membership.JOIN] + ) ) - yield self.store.persist_event(event, context) + self.assertEquals([self.room], [m.room_id for m in rooms_for_user]) + + def test_count_known_servers(self): + """ + _count_known_servers will calculate how many servers are in a room. + """ + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.inject_room_member(self.room, self.u_bob, Membership.JOIN) + self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) + + servers = self.get_success(self.store._count_known_servers()) + self.assertEqual(servers, 2) + + def test_count_known_servers_stat_counter_disabled(self): + """ + If enabled, the metrics for how many servers are known will be counted. + """ + self.assertTrue("_known_servers_count" not in self.store.__dict__.keys()) + + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.inject_room_member(self.room, self.u_bob, Membership.JOIN) + self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) + + self.pump(20) + + self.assertTrue("_known_servers_count" not in self.store.__dict__.keys()) + + @unittest.override_config( + {"enable_metrics": True, "metrics_flags": {"known_servers": True}} + ) + def test_count_known_servers_stat_counter_enabled(self): + """ + If enabled, the metrics for how many servers are known will be counted. + """ + # Initialises to 1 -- itself + self.assertEqual(self.store._known_servers_count, 1) + + self.pump(20) + + # No rooms have been joined, so technically the SQL returns 0, but it + # will still say it knows about itself. + self.assertEqual(self.store._known_servers_count, 1) + + self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) + self.inject_room_member(self.room, self.u_bob, Membership.JOIN) + self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) + + self.pump(20) + + # It now knows about Charlie's server. + self.assertEqual(self.store._known_servers_count, 2) + + +class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() + self.room_creator = homeserver.get_room_creation_handler() + + def test_can_rerun_update(self): + # First make sure we have completed all updates. + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) + + # Now let's create a room, which will insert a membership + user = UserID("alice", "test") + requester = Requester(user, None, False, None, None) + self.get_success(self.room_creator.create_room(requester, {})) + + # Register the background update to run again. + self.get_success( + self.store.db.simple_insert( + table="background_updates", + values={ + "update_name": "current_state_events_membership", + "progress_json": "{}", + "depends_on": None, + }, + ) + ) - defer.returnValue(event) + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db.updates._all_done = False - @defer.inlineCallbacks - def test_one_member(self): - yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN) - - self.assertEquals( - [self.room.to_string()], - [ - m.room_id - for m in ( - yield self.store.get_rooms_for_user_where_membership_is( - self.u_alice.to_string(), [Membership.JOIN] - ) - ) - ], - ) + # Now let's actually drive the updates to completion + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index b6169436de..0b88308ff4 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py
@@ -34,6 +34,8 @@ class StateStoreTestCase(tests.unittest.TestCase): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.storage = hs.get_storage() + self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -43,7 +45,10 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") yield self.store.store_room( - self.room.to_string(), room_creator_user_id="@creator:text", is_public=True + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, ) @defer.inlineCallbacks @@ -63,9 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase): builder ) - yield self.store.persist_event(event, context) + yield self.storage.persistence.persist_event(event, context) - defer.returnValue(event) + return event def assertStateMapEqual(self, s1, s2): for t in s1: @@ -76,32 +81,34 @@ class StateStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_state_groups_ids(self): e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups_ids( + state_group_map = yield self.storage.state.get_state_groups_ids( self.room, [e2.event_id] ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] self.assertDictEqual( state_map, - {(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id}, + {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, ) @defer.inlineCallbacks def test_get_state_groups(self): e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) + state_group_map = yield self.storage.state.get_state_groups( + self.room, [e2.event_id] + ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -113,10 +120,10 @@ class StateStoreTestCase(tests.unittest.TestCase): # this defaults to a linear DAG as each new injection defaults to whatever # forward extremities are currently in the DB for this room. e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, '', {} + self.room, self.u_alice, EventTypes.Create, "", {} ) e2 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) e3 = yield self.inject_state_event( self.room, @@ -141,7 +148,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.store.get_state_for_event(e5.event_id) + state = yield self.storage.state.get_state_for_event(e5.event_id) self.assertIsNotNone(e4) @@ -157,21 +164,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.store.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, '')]) + state = yield self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -181,7 +188,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: {self.u_alice.to_string()}}, @@ -199,7 +206,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.store.get_state_for_event( + state = yield self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -215,13 +222,18 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) + group_ids = yield self.storage.state.get_state_groups_ids( + room_id, [e5.event_id] + ) group = list(group_ids.keys())[0] # test _get_state_for_group_using_cache correctly filters out members # with types=[] - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -237,8 +249,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -250,8 +265,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -267,8 +285,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -287,8 +308,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -304,8 +328,11 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -317,8 +344,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -331,9 +361,11 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### # deliberately remove e2 (room name) from the _state_group_cache - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, True) self.assertEqual(known_absent, set()) @@ -346,21 +378,23 @@ class StateStoreTestCase(tests.unittest.TestCase): ) state_dict_ids.pop((e2.type, e2.state_key)) - self.store._state_group_cache.invalidate(group) - self.store._state_group_cache.update( - sequence=self.store._state_group_cache.sequence, + self.state_datastore._state_group_cache.invalidate(group) + self.state_datastore._state_group_cache.update( + sequence=self.state_datastore._state_group_cache.sequence, key=group, value=state_dict_ids, # list fetched keys so it knows it's partial fetched_keys=((e1.type, e1.state_key),), ) - (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( - group - ) + ( + is_all, + known_absent, + state_dict_ids, + ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, False) - self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) + self.assertEqual(known_absent, {(e1.type, e1.state_key)}) self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) ############################################ @@ -369,8 +403,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -381,8 +418,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: set()}, include_others=True @@ -394,8 +434,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -405,8 +448,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: None}, include_others=True @@ -424,8 +470,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -435,8 +484,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=True @@ -448,8 +500,11 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False @@ -459,8 +514,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_members_cache, + ( + state_dict, + is_all, + ) = yield self.state_datastore._get_state_for_group_using_cache( + self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( types={EventTypes.Member: {e5.state_key}}, include_others=False diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index 14169afa96..8e817e2c7f 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py
@@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.util.retryutils import MAX_RETRY_INTERVAL + from tests.unittest import HomeserverTestCase @@ -29,17 +31,28 @@ class TransactionStoreTestCase(HomeserverTestCase): r = self.get_success(d) self.assertIsNone(r) - d = self.store.set_destination_retry_timings("example.com", 50, 100) + d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) self.get_success(d) d = self.store.get_destination_retry_timings("example.com") r = self.get_success(d) - self.assert_dict({"retry_last_ts": 50, "retry_interval": 100}, r) + self.assert_dict( + {"retry_last_ts": 50, "retry_interval": 100, "failure_ts": 1000}, r + ) def test_initial_set_transactions(self): """Tests that we can successfully set the destination retries (there was a bug around invalidating the cache that broke this) """ - d = self.store.set_destination_retry_timings("example.com", 50, 100) + d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) + self.get_success(d) + + def test_large_destination_retry(self): + d = self.store.set_destination_retry_timings( + "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL + ) + self.get_success(d) + + d = self.store.get_destination_retry_timings("example.com") self.get_success(d) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index d7d244ce97..6a545d2eb0 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py
@@ -15,8 +15,6 @@ from twisted.internet import defer -from synapse.storage import UserDirectoryStore - from tests import unittest from tests.utils import setup_test_homeserver @@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) + self.store = self.hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice.