diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/handlers/test_stats.py | 4 | ||||
-rw-r--r-- | tests/handlers/test_user_directory.py | 4 | ||||
-rw-r--r-- | tests/media/test_media_storage.py | 2 | ||||
-rw-r--r-- | tests/rest/admin/test_media.py | 16 | ||||
-rw-r--r-- | tests/rest/admin/test_user.py | 2 | ||||
-rw-r--r-- | tests/rest/client/test_account.py | 8 | ||||
-rw-r--r-- | tests/rest/client/test_register.py | 12 | ||||
-rw-r--r-- | tests/rest/media/test_media_retention.py | 20 | ||||
-rw-r--r-- | tests/storage/databases/main/test_cache.py | 117 | ||||
-rw-r--r-- | tests/storage/test_base.py | 4 | ||||
-rw-r--r-- | tests/storage/test_room.py | 13 | ||||
-rw-r--r-- | tests/utils.py | 2 |
12 files changed, 158 insertions, 46 deletions
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 76c56d5434..15e19b15fb 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) - return self.get_success( + row = self.get_success( self.store.db_pool.simple_select_one( table + "_current", {id_col: stat_id}, @@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) + return None if row is None else dict(zip(cols, row)) + def _perform_background_initial_update(self) -> None: # Do the initial population of the stats via the background update self._add_background_updates() diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index b5f15aa7d4..388447eea6 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) profile = self.get_success(self.store._get_user_in_directory(regular_user_id)) assert profile is not None - self.assertTrue(profile["display_name"] == display_name) + self.assertTrue(profile[0] == display_name) def test_handle_local_profile_change_with_deactivated_user(self) -> None: # create user @@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # profile is in directory profile = self.get_success(self.store._get_user_in_directory(r_user_id)) assert profile is not None - self.assertTrue(profile["display_name"] == display_name) + self.assertEqual(profile[0], display_name) # deactivate user self.get_success(self.store.set_user_deactivated_status(r_user_id, True)) diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index 15f5d644e4..a8e7a76b29 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): origin, media_id = self.media_id.split("/") info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) assert info is not None - file_id = info["filesystem_id"] + file_id = info.filesystem_id thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( origin, file_id diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 278808abb5..dac79bd745 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["quarantined_by"]) + self.assertFalse(media_info.quarantined_by) # quarantining channel = self.make_request( @@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertTrue(media_info["quarantined_by"]) + self.assertTrue(media_info.quarantined_by) # remove from quarantine channel = self.make_request( @@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["quarantined_by"]) + self.assertFalse(media_info.quarantined_by) def test_quarantine_protected_media(self) -> None: """ @@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): # verify protection media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertTrue(media_info["safe_from_quarantine"]) + self.assertTrue(media_info.safe_from_quarantine) # quarantining channel = self.make_request( @@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests): # verify that is not in quarantine media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["quarantined_by"]) + self.assertFalse(media_info.quarantined_by) class ProtectMediaByIDTestCase(_AdminMediaTests): @@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["safe_from_quarantine"]) + self.assertFalse(media_info.safe_from_quarantine) # protect channel = self.make_request( @@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertTrue(media_info["safe_from_quarantine"]) + self.assertTrue(media_info.safe_from_quarantine) # unprotect channel = self.make_request( @@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests): media_info = self.get_success(self.store.get_local_media(self.media_id)) assert media_info is not None - self.assertFalse(media_info["safe_from_quarantine"]) + self.assertFalse(media_info.safe_from_quarantine) class PurgeMediaCacheTestCase(_AdminMediaTests): diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 37f37a09d8..42b065d883 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # is in user directory profile = self.get_success(self.store._get_user_in_directory(self.other_user)) assert profile is not None - self.assertTrue(profile["display_name"] == "User") + self.assertEqual(profile[0], "User") # Deactivate user channel = self.make_request( diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index cffbda9a7d..bd59bb50cf 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # # Note that we don't have the UI Auth session ID, so just pull out the single # row. - ui_auth_data = self.get_success( - self.store.db_pool.simple_select_one( - "ui_auth_sessions", keyvalues={}, retcols=("clientdict",) + result = self.get_success( + self.store.db_pool.simple_select_one_onecol( + "ui_auth_sessions", keyvalues={}, retcol="clientdict" ) ) - client_dict = db_to_json(ui_auth_data["clientdict"]) + client_dict = db_to_json(result) self.assertNotIn("new_password", client_dict) @override_config({"rc_3pid_validation": {"burst_count": 3}}) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index ba4e017a0e..b04094b7b3 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertLessEqual(det_data.items(), channel.json_body.items()) # Check the `completed` counter has been incremented and pending is 0 - res = self.get_success( + pending, completed = self.get_success( store.db_pool.simple_select_one( "registration_tokens", keyvalues={"token": token}, retcols=["pending", "completed"], ) ) - self.assertEqual(res["completed"], 1) - self.assertEqual(res["pending"], 0) + self.assertEqual(completed, 1) + self.assertEqual(pending, 0) @override_config({"registration_requires_token": True}) def test_POST_registration_token_invalid(self) -> None: @@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): params1["auth"]["type"] = LoginType.DUMMY self.make_request(b"POST", self.url, params1) # Check pending=0 and completed=1 - res = self.get_success( + pending, completed = self.get_success( store.db_pool.simple_select_one( "registration_tokens", keyvalues={"token": token}, retcols=["pending", "completed"], ) ) - self.assertEqual(res["pending"], 0) - self.assertEqual(res["completed"], 1) + self.assertEqual(pending, 0) + self.assertEqual(completed, 1) # Check auth still fails when using token with session2 channel = self.make_request(b"POST", self.url, params2) diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py index b59d9dfd4d..27a663a23b 100644 --- a/tests/rest/media/test_media_retention.py +++ b/tests/rest/media/test_media_retention.py @@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase): def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None: """Given an MXC URI, assert whether it has been purged or not.""" if mxc_uri.server_name == self.hs.config.server.server_name: - found_media_dict = self.get_success( - self.store.get_local_media(mxc_uri.media_id) + found_media = bool( + self.get_success(self.store.get_local_media(mxc_uri.media_id)) ) else: - found_media_dict = self.get_success( - self.store.get_cached_remote_media( - mxc_uri.server_name, mxc_uri.media_id + found_media = bool( + self.get_success( + self.store.get_cached_remote_media( + mxc_uri.server_name, mxc_uri.media_id + ) ) ) if expect_purged: - self.assertIsNone( - found_media_dict, msg=f"{mxc_uri} unexpectedly not purged" - ) + self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged") else: - self.assertIsNotNone( - found_media_dict, + self.assertTrue( + found_media, msg=f"{mxc_uri} unexpectedly purged", ) diff --git a/tests/storage/databases/main/test_cache.py b/tests/storage/databases/main/test_cache.py new file mode 100644 index 0000000000..3f71f5d102 --- /dev/null +++ b/tests/storage/databases/main/test_cache.py @@ -0,0 +1,117 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock, call + +from synapse.storage.database import LoggingTransaction + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.unittest import HomeserverTestCase + + +class CacheInvalidationTestCase(HomeserverTestCase): + def setUp(self) -> None: + super().setUp() + self.store = self.hs.get_datastores().main + + def test_bulk_invalidation(self) -> None: + master_invalidate = Mock() + + self.store._get_cached_user_device.invalidate = master_invalidate + + keys_to_invalidate = [ + ("a", "b"), + ("c", "d"), + ("e", "f"), + ("g", "h"), + ] + + def test_txn(txn: LoggingTransaction) -> None: + self.store._invalidate_cache_and_stream_bulk( + txn, + # This is an arbitrarily chosen cached store function. It was chosen + # because it takes more than one argument. We'll use this later to + # check that the invalidation was actioned over replication. + cache_func=self.store._get_cached_user_device, + key_tuples=keys_to_invalidate, + ) + + self.get_success( + self.store.db_pool.runInteraction( + "test_invalidate_cache_and_stream_bulk", test_txn + ) + ) + + master_invalidate.assert_has_calls( + [call(key_list) for key_list in keys_to_invalidate], + any_order=True, + ) + + +class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase): + def setUp(self) -> None: + super().setUp() + self.store = self.hs.get_datastores().main + + def test_bulk_invalidation_replicates(self) -> None: + """Like test_bulk_invalidation, but also checks the invalidations replicate.""" + master_invalidate = Mock() + worker_invalidate = Mock() + + self.store._get_cached_user_device.invalidate = master_invalidate + worker = self.make_worker_hs("synapse.app.generic_worker") + worker_ds = worker.get_datastores().main + worker_ds._get_cached_user_device.invalidate = worker_invalidate + + keys_to_invalidate = [ + ("a", "b"), + ("c", "d"), + ("e", "f"), + ("g", "h"), + ] + + def test_txn(txn: LoggingTransaction) -> None: + self.store._invalidate_cache_and_stream_bulk( + txn, + # This is an arbitrarily chosen cached store function. It was chosen + # because it takes more than one argument. We'll use this later to + # check that the invalidation was actioned over replication. + cache_func=self.store._get_cached_user_device, + key_tuples=keys_to_invalidate, + ) + + assert self.store._cache_id_gen is not None + initial_token = self.store._cache_id_gen.get_current_token() + self.get_success( + self.database_pool.runInteraction( + "test_invalidate_cache_and_stream_bulk", test_txn + ) + ) + second_token = self.store._cache_id_gen.get_current_token() + + self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate)) + + self.get_success( + worker.get_replication_data_handler().wait_for_stream_position( + "master", "caches", second_token + ) + ) + + master_invalidate.assert_has_calls( + [call(key_list) for key_list in keys_to_invalidate], + any_order=True, + ) + worker_invalidate.assert_has_calls( + [call(key_list) for key_list in keys_to_invalidate], + any_order=True, + ) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index f34b6b2dcf..491e6d5e63 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) + self.assertEqual((1, 2, 3), ret) self.mock_txn.execute.assert_called_once_with( "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertFalse(ret) + self.assertIsNone(ret) @defer.inlineCallbacks def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ce34195a25..d3ffe963d3 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase): ) def test_get_room(self) -> None: - res = self.get_success(self.store.get_room(self.room.to_string())) - assert res is not None - self.assertLessEqual( - { - "room_id": self.room.to_string(), - "creator": self.u_creator.to_string(), - "is_public": True, - }.items(), - res.items(), - ) + room = self.get_success(self.store.get_room(self.room.to_string())) + assert room is not None + self.assertTrue(room[0]) def test_get_room_unknown_room(self) -> None: self.assertIsNone(self.get_success(self.store.get_room("!uknown:test"))) diff --git a/tests/utils.py b/tests/utils.py index 9be02b8ea7..c44e5cb4ee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -83,11 +83,11 @@ def setupdb() -> None: # Set up in the db db_conn = db_engine.module.connect( + dbname=POSTGRES_BASE_DB, user=POSTGRES_USER, host=POSTGRES_HOST, port=POSTGRES_PORT, password=POSTGRES_PASSWORD, - dbname=POSTGRES_BASE_DB, ) logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") prepare_database(logging_conn, db_engine, None) |