summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2023-11-09 11:14:57 -0500
committerPatrick Cloke <patrickc@matrix.org>2023-11-09 11:14:57 -0500
commit8c2d3d0b4cf674ec9b8fa582d50e66cd0960a73b (patch)
tree115f6d18fb7dbbfc710f40a68d7c4d0007be056c /tests/storage
parentMerge remote-tracking branch 'origin/develop' into matrix-org-hotfixes (diff)
parentConvert simple_select_one_txn and simple_select_one to return tuples. (#16612) (diff)
downloadsynapse-8c2d3d0b4cf674ec9b8fa582d50e66cd0960a73b.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/databases/main/test_cache.py117
-rw-r--r--tests/storage/test_base.py4
-rw-r--r--tests/storage/test_room.py13
3 files changed, 122 insertions, 12 deletions
diff --git a/tests/storage/databases/main/test_cache.py b/tests/storage/databases/main/test_cache.py
new file mode 100644

index 0000000000..3f71f5d102 --- /dev/null +++ b/tests/storage/databases/main/test_cache.py
@@ -0,0 +1,117 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock, call + +from synapse.storage.database import LoggingTransaction + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.unittest import HomeserverTestCase + + +class CacheInvalidationTestCase(HomeserverTestCase): + def setUp(self) -> None: + super().setUp() + self.store = self.hs.get_datastores().main + + def test_bulk_invalidation(self) -> None: + master_invalidate = Mock() + + self.store._get_cached_user_device.invalidate = master_invalidate + + keys_to_invalidate = [ + ("a", "b"), + ("c", "d"), + ("e", "f"), + ("g", "h"), + ] + + def test_txn(txn: LoggingTransaction) -> None: + self.store._invalidate_cache_and_stream_bulk( + txn, + # This is an arbitrarily chosen cached store function. It was chosen + # because it takes more than one argument. We'll use this later to + # check that the invalidation was actioned over replication. + cache_func=self.store._get_cached_user_device, + key_tuples=keys_to_invalidate, + ) + + self.get_success( + self.store.db_pool.runInteraction( + "test_invalidate_cache_and_stream_bulk", test_txn + ) + ) + + master_invalidate.assert_has_calls( + [call(key_list) for key_list in keys_to_invalidate], + any_order=True, + ) + + +class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase): + def setUp(self) -> None: + super().setUp() + self.store = self.hs.get_datastores().main + + def test_bulk_invalidation_replicates(self) -> None: + """Like test_bulk_invalidation, but also checks the invalidations replicate.""" + master_invalidate = Mock() + worker_invalidate = Mock() + + self.store._get_cached_user_device.invalidate = master_invalidate + worker = self.make_worker_hs("synapse.app.generic_worker") + worker_ds = worker.get_datastores().main + worker_ds._get_cached_user_device.invalidate = worker_invalidate + + keys_to_invalidate = [ + ("a", "b"), + ("c", "d"), + ("e", "f"), + ("g", "h"), + ] + + def test_txn(txn: LoggingTransaction) -> None: + self.store._invalidate_cache_and_stream_bulk( + txn, + # This is an arbitrarily chosen cached store function. It was chosen + # because it takes more than one argument. We'll use this later to + # check that the invalidation was actioned over replication. + cache_func=self.store._get_cached_user_device, + key_tuples=keys_to_invalidate, + ) + + assert self.store._cache_id_gen is not None + initial_token = self.store._cache_id_gen.get_current_token() + self.get_success( + self.database_pool.runInteraction( + "test_invalidate_cache_and_stream_bulk", test_txn + ) + ) + second_token = self.store._cache_id_gen.get_current_token() + + self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate)) + + self.get_success( + worker.get_replication_data_handler().wait_for_stream_position( + "master", "caches", second_token + ) + ) + + master_invalidate.assert_has_calls( + [call(key_list) for key_list in keys_to_invalidate], + any_order=True, + ) + worker_invalidate.assert_has_calls( + [call(key_list) for key_list in keys_to_invalidate], + any_order=True, + ) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index f34b6b2dcf..491e6d5e63 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py
@@ -222,7 +222,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret) + self.assertEqual((1, 2, 3), ret) self.mock_txn.execute.assert_called_once_with( "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"] ) @@ -243,7 +243,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): ) ) - self.assertFalse(ret) + self.assertIsNone(ret) @defer.inlineCallbacks def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]: diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index ce34195a25..d3ffe963d3 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py
@@ -42,16 +42,9 @@ class RoomStoreTestCase(HomeserverTestCase): ) def test_get_room(self) -> None: - res = self.get_success(self.store.get_room(self.room.to_string())) - assert res is not None - self.assertLessEqual( - { - "room_id": self.room.to_string(), - "creator": self.u_creator.to_string(), - "is_public": True, - }.items(), - res.items(), - ) + room = self.get_success(self.store.get_room(self.room.to_string())) + assert room is not None + self.assertTrue(room[0]) def test_get_room_unknown_room(self) -> None: self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))