summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_stats.py4
-rw-r--r--tests/handlers/test_user_directory.py4
-rw-r--r--tests/media/test_media_storage.py2
-rw-r--r--tests/rest/admin/test_media.py16
-rw-r--r--tests/rest/admin/test_user.py2
-rw-r--r--tests/rest/client/test_account.py8
-rw-r--r--tests/rest/client/test_register.py12
-rw-r--r--tests/rest/media/test_media_retention.py20
-rw-r--r--tests/storage/databases/main/test_cache.py117
-rw-r--r--tests/storage/test_base.py4
-rw-r--r--tests/storage/test_room.py13
-rw-r--r--tests/utils.py2
12 files changed, 158 insertions, 46 deletions
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 76c56d5434..15e19b15fb 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -84,7 +84,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
 
-        return self.get_success(
+        row = self.get_success(
             self.store.db_pool.simple_select_one(
                 table + "_current",
                 {id_col: stat_id},
@@ -93,6 +93,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
 
+        return None if row is None else dict(zip(cols, row))
+
     def _perform_background_initial_update(self) -> None:
         # Do the initial population of the stats via the background update
         self._add_background_updates()
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index b5f15aa7d4..388447eea6 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -366,7 +366,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
         profile = self.get_success(self.store._get_user_in_directory(regular_user_id))
         assert profile is not None
-        self.assertTrue(profile["display_name"] == display_name)
+        self.assertTrue(profile[0] == display_name)
 
     def test_handle_local_profile_change_with_deactivated_user(self) -> None:
         # create user
@@ -385,7 +385,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         # profile is in directory
         profile = self.get_success(self.store._get_user_in_directory(r_user_id))
         assert profile is not None
-        self.assertTrue(profile["display_name"] == display_name)
+        self.assertEqual(profile[0], display_name)
 
         # deactivate user
         self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index 15f5d644e4..a8e7a76b29 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -504,7 +504,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         origin, media_id = self.media_id.split("/")
         info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
         assert info is not None
-        file_id = info["filesystem_id"]
+        file_id = info.filesystem_id
 
         thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
             origin, file_id
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 278808abb5..dac79bd745 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
         # quarantining
         channel = self.make_request(
@@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["quarantined_by"])
+        self.assertTrue(media_info.quarantined_by)
 
         # remove from quarantine
         channel = self.make_request(
@@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
     def test_quarantine_protected_media(self) -> None:
         """
@@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
         # verify protection
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["safe_from_quarantine"])
+        self.assertTrue(media_info.safe_from_quarantine)
 
         # quarantining
         channel = self.make_request(
@@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
         # verify that is not in quarantine
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["quarantined_by"])
+        self.assertFalse(media_info.quarantined_by)
 
 
 class ProtectMediaByIDTestCase(_AdminMediaTests):
@@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["safe_from_quarantine"])
+        self.assertFalse(media_info.safe_from_quarantine)
 
         # protect
         channel = self.make_request(
@@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertTrue(media_info["safe_from_quarantine"])
+        self.assertTrue(media_info.safe_from_quarantine)
 
         # unprotect
         channel = self.make_request(
@@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
         assert media_info is not None
-        self.assertFalse(media_info["safe_from_quarantine"])
+        self.assertFalse(media_info.safe_from_quarantine)
 
 
 class PurgeMediaCacheTestCase(_AdminMediaTests):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 37f37a09d8..42b065d883 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # is in user directory
         profile = self.get_success(self.store._get_user_in_directory(self.other_user))
         assert profile is not None
-        self.assertTrue(profile["display_name"] == "User")
+        self.assertEqual(profile[0], "User")
 
         # Deactivate user
         channel = self.make_request(
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index cffbda9a7d..bd59bb50cf 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         #
         # Note that we don't have the UI Auth session ID, so just pull out the single
         # row.
-        ui_auth_data = self.get_success(
-            self.store.db_pool.simple_select_one(
-                "ui_auth_sessions", keyvalues={}, retcols=("clientdict",)
+        result = self.get_success(
+            self.store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions", keyvalues={}, retcol="clientdict"
             )
         )
-        client_dict = db_to_json(ui_auth_data["clientdict"])
+        client_dict = db_to_json(result)
         self.assertNotIn("new_password", client_dict)
 
     @override_config({"rc_3pid_validation": {"burst_count": 3}})
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index ba4e017a0e..b04094b7b3 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertLessEqual(det_data.items(), channel.json_body.items())
 
         # Check the `completed` counter has been incremented and pending is 0
-        res = self.get_success(
+        pending, completed = self.get_success(
             store.db_pool.simple_select_one(
                 "registration_tokens",
                 keyvalues={"token": token},
                 retcols=["pending", "completed"],
             )
         )
-        self.assertEqual(res["completed"], 1)
-        self.assertEqual(res["pending"], 0)
+        self.assertEqual(completed, 1)
+        self.assertEqual(pending, 0)
 
     @override_config({"registration_requires_token": True})
     def test_POST_registration_token_invalid(self) -> None:
@@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         params1["auth"]["type"] = LoginType.DUMMY
         self.make_request(b"POST", self.url, params1)
         # Check pending=0 and completed=1
-        res = self.get_success(
+        pending, completed = self.get_success(
             store.db_pool.simple_select_one(
                 "registration_tokens",
                 keyvalues={"token": token},
                 retcols=["pending", "completed"],
             )
         )
-        self.assertEqual(res["pending"], 0)
-        self.assertEqual(res["completed"], 1)
+        self.assertEqual(pending, 0)
+        self.assertEqual(completed, 1)
 
         # Check auth still fails when using token with session2
         channel = self.make_request(b"POST", self.url, params2)
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index b59d9dfd4d..27a663a23b 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
         def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
             """Given an MXC URI, assert whether it has been purged or not."""
             if mxc_uri.server_name == self.hs.config.server.server_name:
-                found_media_dict = self.get_success(
-                    self.store.get_local_media(mxc_uri.media_id)
+                found_media = bool(
+                    self.get_success(self.store.get_local_media(mxc_uri.media_id))
                 )
             else:
-                found_media_dict = self.get_success(
-                    self.store.get_cached_remote_media(
-                        mxc_uri.server_name, mxc_uri.media_id
+                found_media = bool(
+                    self.get_success(
+                        self.store.get_cached_remote_media(
+                            mxc_uri.server_name, mxc_uri.media_id
+                        )
                     )
                 )
 
             if expect_purged:
-                self.assertIsNone(
-                    found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
-                )
+                self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged")
             else:
-                self.assertIsNotNone(
-                    found_media_dict,
+                self.assertTrue(
+                    found_media,
                     msg=f"{mxc_uri} unexpectedly purged",
                 )
 
diff --git a/tests/storage/databases/main/test_cache.py b/tests/storage/databases/main/test_cache.py
new file mode 100644
index 0000000000..3f71f5d102
--- /dev/null
+++ b/tests/storage/databases/main/test_cache.py
@@ -0,0 +1,117 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock, call
+
+from synapse.storage.database import LoggingTransaction
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.unittest import HomeserverTestCase
+
+
+class CacheInvalidationTestCase(HomeserverTestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.store = self.hs.get_datastores().main
+
+    def test_bulk_invalidation(self) -> None:
+        master_invalidate = Mock()
+
+        self.store._get_cached_user_device.invalidate = master_invalidate
+
+        keys_to_invalidate = [
+            ("a", "b"),
+            ("c", "d"),
+            ("e", "f"),
+            ("g", "h"),
+        ]
+
+        def test_txn(txn: LoggingTransaction) -> None:
+            self.store._invalidate_cache_and_stream_bulk(
+                txn,
+                # This is an arbitrarily chosen cached store function. It was chosen
+                # because it takes more than one argument. We'll use this later to
+                # check that the invalidation was actioned over replication.
+                cache_func=self.store._get_cached_user_device,
+                key_tuples=keys_to_invalidate,
+            )
+
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test_invalidate_cache_and_stream_bulk", test_txn
+            )
+        )
+
+        master_invalidate.assert_has_calls(
+            [call(key_list) for key_list in keys_to_invalidate],
+            any_order=True,
+        )
+
+
+class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.store = self.hs.get_datastores().main
+
+    def test_bulk_invalidation_replicates(self) -> None:
+        """Like test_bulk_invalidation, but also checks the invalidations replicate."""
+        master_invalidate = Mock()
+        worker_invalidate = Mock()
+
+        self.store._get_cached_user_device.invalidate = master_invalidate
+        worker = self.make_worker_hs("synapse.app.generic_worker")
+        worker_ds = worker.get_datastores().main
+        worker_ds._get_cached_user_device.invalidate = worker_invalidate
+
+        keys_to_invalidate = [
+            ("a", "b"),
+            ("c", "d"),
+            ("e", "f"),
+            ("g", "h"),
+        ]
+
+        def test_txn(txn: LoggingTransaction) -> None:
+            self.store._invalidate_cache_and_stream_bulk(
+                txn,
+                # This is an arbitrarily chosen cached store function. It was chosen
+                # because it takes more than one argument. We'll use this later to
+                # check that the invalidation was actioned over replication.
+                cache_func=self.store._get_cached_user_device,
+                key_tuples=keys_to_invalidate,
+            )
+
+        assert self.store._cache_id_gen is not None
+        initial_token = self.store._cache_id_gen.get_current_token()
+        self.get_success(
+            self.database_pool.runInteraction(
+                "test_invalidate_cache_and_stream_bulk", test_txn
+            )
+        )
+        second_token = self.store._cache_id_gen.get_current_token()
+
+        self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate))
+
+        self.get_success(
+            worker.get_replication_data_handler().wait_for_stream_position(
+                "master", "caches", second_token
+            )
+        )
+
+        master_invalidate.assert_has_calls(
+            [call(key_list) for key_list in keys_to_invalidate],
+            any_order=True,
+        )
+        worker_invalidate.assert_has_calls(
+            [call(key_list) for key_list in keys_to_invalidate],
+            any_order=True,
+        )
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index f34b6b2dcf..491e6d5e63 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             )
         )
 
-        self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret)
+        self.assertEqual((1, 2, 3), ret)
         self.mock_txn.execute.assert_called_once_with(
             "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
         )
@@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             )
         )
 
-        self.assertFalse(ret)
+        self.assertIsNone(ret)
 
     @defer.inlineCallbacks
     def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index ce34195a25..d3ffe963d3 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase):
         )
 
     def test_get_room(self) -> None:
-        res = self.get_success(self.store.get_room(self.room.to_string()))
-        assert res is not None
-        self.assertLessEqual(
-            {
-                "room_id": self.room.to_string(),
-                "creator": self.u_creator.to_string(),
-                "is_public": True,
-            }.items(),
-            res.items(),
-        )
+        room = self.get_success(self.store.get_room(self.room.to_string()))
+        assert room is not None
+        self.assertTrue(room[0])
 
     def test_get_room_unknown_room(self) -> None:
         self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
diff --git a/tests/utils.py b/tests/utils.py
index 9be02b8ea7..c44e5cb4ee 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -83,11 +83,11 @@ def setupdb() -> None:
 
         # Set up in the db
         db_conn = db_engine.module.connect(
+            dbname=POSTGRES_BASE_DB,
             user=POSTGRES_USER,
             host=POSTGRES_HOST,
             port=POSTGRES_PORT,
             password=POSTGRES_PASSWORD,
-            dbname=POSTGRES_BASE_DB,
         )
         logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
         prepare_database(logging_conn, db_engine, None)