summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/events/test_snapshot.py4
-rw-r--r--tests/handlers/test_federation.py6
-rw-r--r--tests/handlers/test_federation_event.py9
-rw-r--r--tests/handlers/test_message.py14
-rw-r--r--tests/handlers/test_user_directory.py2
-rw-r--r--tests/replication/slave/storage/_base.py2
-rw-r--r--tests/replication/slave/storage/test_events.py10
-rw-r--r--tests/replication/slave/storage/test_receipts.py12
-rw-r--r--tests/rest/admin/test_user.py4
-rw-r--r--tests/rest/client/test_retention.py4
-rw-r--r--tests/rest/client/test_room_batch.py6
-rw-r--r--tests/storage/test_event_chain.py3
-rw-r--r--tests/storage/test_events.py12
-rw-r--r--tests/storage/test_purge.py14
-rw-r--r--tests/storage/test_redaction.py14
-rw-r--r--tests/storage/test_room.py4
-rw-r--r--tests/storage/test_room_search.py4
-rw-r--r--tests/storage/test_state.py2
-rw-r--r--tests/test_state.py6
-rw-r--r--tests/test_utils/event_injection.py2
-rw-r--r--tests/test_visibility.py46
-rw-r--r--tests/utils.py2
22 files changed, 115 insertions, 67 deletions
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index defbc68c18..8ddce83b83 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
+        self._storage_controllers = hs.get_storage_controllers()
 
         self.user_id = self.register_user("u1", "pass")
         self.user_tok = self.login("u1", "pass")
@@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase):
     def _check_serialize_deserialize(self, event, context):
         serialized = self.get_success(context.serialize(event, self.store))
 
-        d_context = EventContext.deserialize(self.storage, serialized)
+        d_context = EventContext.deserialize(self._storage_controllers, serialized)
 
         self.assertEqual(context.state_group, d_context.state_group)
         self.assertEqual(context.rejected, d_context.rejected)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index ec00900621..500c9ccfbc 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
         hs = self.setup_test_homeserver(federation_http_client=None)
         self.handler = hs.get_federation_handler()
         self.store = hs.get_datastores().main
-        self.state_storage = hs.get_storage().state
+        self.state_storage_controller = hs.get_storage_controllers().state
         self._event_auth_handler = hs.get_event_auth_handler()
         return hs
 
@@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
         # mapping from (type, state_key) -> state_event_id
         assert most_recent_prev_event_id is not None
         prev_state_map = self.get_success(
-            self.state_storage.get_state_ids_for_event(most_recent_prev_event_id)
+            self.state_storage_controller.get_state_ids_for_event(
+                most_recent_prev_event_id
+            )
         )
         # List of state event ID's
         prev_state_ids = list(prev_state_map.values())
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index e64b28f28b..1d5b2492c0 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
     ) -> None:
         OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
         main_store = self.hs.get_datastores().main
-        state_storage = self.hs.get_storage().state
+        state_storage_controller = self.hs.get_storage_controllers().state
 
         # create the room
         user_id = self.register_user("kermit", "test")
@@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         )
         if prev_exists_as_outlier:
             prev_event.internal_metadata.outlier = True
-            persistence = self.hs.get_storage().persistence
+            persistence = self.hs.get_storage_controllers().persistence
             self.get_success(
                 persistence.persist_event(
-                    prev_event, EventContext.for_outlier(self.hs.get_storage())
+                    prev_event,
+                    EventContext.for_outlier(self.hs.get_storage_controllers()),
                 )
             )
         else:
@@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
 
         # check that the state at that event is as expected
         state = self.get_success(
-            state_storage.get_state_ids_for_event(pulled_event.event_id)
+            state_storage_controller.get_state_ids_for_event(pulled_event.event_id)
         )
         expected_state = {
             (e.type, e.state_key): e.event_id for e in state_at_prev_event
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index f4f7ab4845..44da96c792 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.handler = self.hs.get_event_creation_handler()
-        self.persist_event_storage = self.hs.get_storage().persistence
+        self._persist_event_storage_controller = (
+            self.hs.get_storage_controllers().persistence
+        )
 
         self.user_id = self.register_user("tester", "foobar")
         self.access_token = self.login("tester", "foobar")
@@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+            self._persist_event_storage_controller.persist_event(
+                memberEvent, memberEventContext
+            )
         )
 
         return memberEvent, memberEventContext
@@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         self.assertNotEqual(event1.event_id, event3.event_id)
 
         ret_event3, event_pos3, _ = self.get_success(
-            self.persist_event_storage.persist_event(event3, context)
+            self._persist_event_storage_controller.persist_event(event3, context)
         )
 
         # Assert that the returned values match those from the initial event
@@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         self.assertNotEqual(event1.event_id, event3.event_id)
 
         events, _ = self.get_success(
-            self.persist_event_storage.persist_events([(event3, context)])
+            self._persist_event_storage_controller.persist_events([(event3, context)])
         )
         ret_event4 = events[0]
 
@@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         self.assertNotEqual(event1.event_id, event2.event_id)
 
         events, _ = self.get_success(
-            self.persist_event_storage.persist_events(
+            self._persist_event_storage_controller.persist_events(
                 [(event1, context1), (event2, context2)]
             )
         )
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 4d658d29ca..a68c2ffd45 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.hs.get_storage().persistence.persist_event(event, context)
+            self.hs.get_storage_controllers().persistence.persist_event(event, context)
         )
 
     def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 85be79d19d..c5705256e6 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
 
         self.master_store = hs.get_datastores().main
         self.slaved_store = self.worker_hs.get_datastores().main
-        self.storage = hs.get_storage()
+        self._storage_controllers = hs.get_storage_controllers()
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 297a9e77f8..6d3d4afe52 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         )
         msg, msgctx = self.build_event()
         self.get_success(
-            self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
+            self._storage_controllers.persistence.persist_events(
+                [(j2, j2ctx), (msg, msgctx)]
+            )
         )
         self.replicate()
 
@@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
         if backfill:
             self.get_success(
-                self.storage.persistence.persist_events(
+                self._storage_controllers.persistence.persist_events(
                     [(event, context)], backfilled=True
                 )
             )
         else:
-            self.get_success(self.storage.persistence.persist_event(event, context))
+            self.get_success(
+                self._storage_controllers.persistence.persist_event(event, context)
+            )
 
         return event
 
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index 5bbbd5fbcb..19f57115a1 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
     def prepare(self, reactor, clock, homeserver):
         super().prepare(reactor, clock, homeserver)
         self.room_creator = homeserver.get_room_creation_handler()
-        self.persist_event_storage = self.hs.get_storage().persistence
+        self.persist_event_storage_controller = (
+            self.hs.get_storage_controllers().persistence
+        )
 
         # Create a test user
         self.ourUser = UserID.from_string(OUR_USER_ID)
@@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
             )
         )
         self.get_success(
-            self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+            self.persist_event_storage_controller.persist_event(
+                memberEvent, memberEventContext
+            )
         )
 
         # Join the second user to the second room
@@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
             )
         )
         self.get_success(
-            self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+            self.persist_event_storage_controller.persist_event(
+                memberEvent, memberEventContext
+            )
         )
 
     def test_return_empty_with_no_data(self):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 0cdf1dec40..0d44102237 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         event_builder_factory = self.hs.get_event_builder_factory()
         event_creation_handler = self.hs.get_event_creation_handler()
-        storage = self.hs.get_storage()
+        storage_controllers = self.hs.get_storage_controllers()
 
         # Create two rooms, one with a local user only and one with both a local
         # and remote user.
@@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
             event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(storage.persistence.persist_event(event, context))
+        self.get_success(storage_controllers.persistence.persist_event(event, context))
 
         # Now get rooms
         url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 2cd7a9e6c5..ac9c113354 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         We do this by setting a very long time between purge jobs.
         """
         store = self.hs.get_datastores().main
-        storage = self.hs.get_storage()
+        storage_controllers = self.hs.get_storage_controllers()
         room_id = self.helper.create_room_as(self.user_id, tok=self.token)
 
         # Send a first event, which should be filtered out at the end of the test.
@@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(2, len(events), "events retrieved from database")
         filtered_events = self.get_success(
-            filter_events_for_client(storage, self.user_id, events)
+            filter_events_for_client(storage_controllers, self.user_id, events)
         )
 
         # We should only get one event back.
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
index 41a1bf6d89..1b7ee08ab2 100644
--- a/tests/rest/client/test_room_batch.py
+++ b/tests/rest/client/test_room_batch.py
@@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.clock = clock
-        self.storage = hs.get_storage()
+        self._storage_controllers = hs.get_storage_controllers()
 
         self.virtual_user_id, _ = self.register_appservice_user(
             "as_user_potato", self.appservice.token
@@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
 
         # Fetch the state_groups
         state_group_map = self.get_success(
-            self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
+            self._storage_controllers.state.get_state_groups_ids(
+                room_id, historical_event_ids
+            )
         )
 
         # We expect all of the historical events to be using the same state_group
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c7661e7186..a0ce077a99 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
             # We need to persist the events to the events and state_events
             # tables.
             persist_events_store._store_event_txn(
-                txn, [(e, EventContext(self.hs.get_storage())) for e in events]
+                txn,
+                [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
             )
 
             # Actually call the function that calculates the auth chain stuff.
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index aaa3189b16..a76718e8f9 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
 
     def prepare(self, reactor, clock, homeserver):
         self.state = self.hs.get_state_handler()
-        self.persistence = self.hs.get_storage().persistence
+        self._persistence = self.hs.get_storage_controllers().persistence
         self.store = self.hs.get_datastores().main
 
         self.register_user("user", "pass")
@@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         context = self.get_success(
             self.state.compute_event_context(event, state_ids_before_event=state)
         )
-        self.get_success(self.persistence.persist_event(event, context))
+        self.get_success(self._persistence.persist_event(event, context))
 
     def assert_extremities(self, expected_extremities):
         """Assert the current extremities for the room"""
@@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
             )
         )
 
-        self.get_success(self.persistence.persist_event(remote_event_2, context))
+        self.get_success(self._persistence.persist_event(remote_event_2, context))
 
         # Check that we haven't dropped the old extremity.
         self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
 
     def prepare(self, reactor, clock, homeserver):
         self.state = self.hs.get_state_handler()
-        self.persistence = self.hs.get_storage().persistence
+        self._persistence = self.hs.get_storage_controllers().persistence
         self.store = self.hs.get_datastores().main
 
     def test_remote_user_rooms_cache_invalidated(self):
@@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
         )
 
         context = self.get_success(self.state.compute_event_context(remote_event_1))
-        self.get_success(self.persistence.persist_event(remote_event_1, context))
+        self.get_success(self._persistence.persist_event(remote_event_1, context))
 
         # Call `get_rooms_for_user` to add the remote user to the cache
         rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
@@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
         )
 
         context = self.get_success(self.state.compute_event_context(remote_event_1))
-        self.get_success(self.persistence.persist_event(remote_event_1, context))
+        self.get_success(self._persistence.persist_event(remote_event_1, context))
 
         # Call `get_users_in_room` to add the remote user to the cache
         users = self.get_success(self.store.get_users_in_room(room_id))
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 08cc60237e..92cd0dfc05 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase):
         self.room_id = self.helper.create_room_as(self.user_id)
 
         self.store = hs.get_datastores().main
-        self.storage = self.hs.get_storage()
+        self._storage_controllers = self.hs.get_storage_controllers()
 
     def test_purge_history(self):
         """
@@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
 
         # Purge everything before this topological token
         self.get_success(
-            self.storage.purge_events.purge_history(self.room_id, token_str, True)
+            self._storage_controllers.purge_events.purge_history(
+                self.room_id, token_str, True
+            )
         )
 
         # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
@@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase):
 
         # Purge everything before this topological token
         f = self.get_failure(
-            self.storage.purge_events.purge_history(self.room_id, event, True),
+            self._storage_controllers.purge_events.purge_history(
+                self.room_id, event, True
+            ),
             SynapseError,
         )
         self.assertIn("greater than forward", f.value.args[0])
@@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase):
         self.assertIsNotNone(create_event)
 
         # Purge everything before this topological token
-        self.get_success(self.storage.purge_events.purge_room(self.room_id))
+        self.get_success(
+            self._storage_controllers.purge_events.purge_room(self.room_id)
+        )
 
         # The events aren't found.
         self.store._invalidate_get_event_cache(create_event.event_id)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d8d17ef379..6c4e63b77c 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
+        self._storage = hs.get_storage_controllers()
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.storage.persistence.persist_event(event, context))
+        self.get_success(self._storage.persistence.persist_event(event, context))
 
         return event
 
@@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.storage.persistence.persist_event(event, context))
+        self.get_success(self._storage.persistence.persist_event(event, context))
 
         return event
 
@@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.storage.persistence.persist_event(event, context))
+        self.get_success(self._storage.persistence.persist_event(event, context))
 
         return event
 
@@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.get_success(self.storage.persistence.persist_event(event_1, context_1))
+        self.get_success(self._storage.persistence.persist_event(event_1, context_1))
 
         event_2, context_2 = self.get_success(
             self.event_creation_handler.create_new_client_event(
@@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 )
             )
         )
-        self.get_success(self.storage.persistence.persist_event(event_2, context_2))
+        self.get_success(self._storage.persistence.persist_event(event_2, context_2))
 
         # fetch one of the redactions
         fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.storage.persistence.persist_event(redaction_event, context)
+            self._storage.persistence.persist_event(redaction_event, context)
         )
 
         # Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 5b011e18cd..d497a19f63 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
+        self._storage = hs.get_storage_controllers()
         self.event_factory = hs.get_event_factory()
 
         self.room = RoomID.from_string("!abcde:test")
@@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
 
     def inject_room_event(self, **kwargs):
         self.get_success(
-            self.storage.persistence.persist_event(
+            self._storage.persistence.persist_event(
                 self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
             )
         )
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 8dfc1e1db9..e747c6b50e 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase):
         prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
         prev_event = self.get_success(store.get_event(prev_event_ids[0]))
         prev_state_map = self.get_success(
-            self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0])
+            self.hs.get_storage_controllers().state.get_state_ids_for_event(
+                prev_event_ids[0]
+            )
         )
 
         event_dict = {
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index f88f1c55fc..8043bdbde2 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 class StateStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
+        self.storage = hs.get_storage_controllers()
         self.state_datastore = self.storage.state.stores.state
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
diff --git a/tests/test_state.py b/tests/test_state.py
index 84694d368d..95f81bebae 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -179,12 +179,12 @@ class Graph:
 class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.dummy_store = _DummyStore()
-        storage = Mock(main=self.dummy_store, state=self.dummy_store)
+        storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
         hs = Mock(
             spec_set=[
                 "config",
                 "get_datastores",
-                "get_storage",
+                "get_storage_controllers",
                 "get_auth",
                 "get_state_handler",
                 "get_clock",
@@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase):
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
         hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
-        hs.get_storage.return_value = storage
+        hs.get_storage_controllers.return_value = storage_controllers
 
         self.state = StateHandler(hs)
         self.event_id = 0
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index c654e36ee4..8027c7a856 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -70,7 +70,7 @@ async def inject_event(
     """
     event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
 
-    persistence = hs.get_storage().persistence
+    persistence = hs.get_storage_controllers().persistence
     assert persistence is not None
 
     await persistence.persist_event(event, context)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 7a9b01ef9d..f338af6c36 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         super(FilterEventsForServerTestCase, self).setUp()
         self.event_creation_handler = self.hs.get_event_creation_handler()
         self.event_builder_factory = self.hs.get_event_builder_factory()
-        self.storage = self.hs.get_storage()
+        self._storage_controllers = self.hs.get_storage_controllers()
 
         self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
 
@@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             events_to_filter.append(evt)
 
         filtered = self.get_success(
-            filter_events_for_server(self.storage, "test_server", events_to_filter)
+            filter_events_for_server(
+                self._storage_controllers, "test_server", events_to_filter
+            )
         )
 
         # the result should be 5 redacted events, and 5 unredacted events.
@@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         outlier = self._inject_outlier()
         self.assertEqual(
             self.get_success(
-                filter_events_for_server(self.storage, "remote_hs", [outlier])
+                filter_events_for_server(
+                    self._storage_controllers, "remote_hs", [outlier]
+                )
             ),
             [outlier],
         )
@@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         evt = self._inject_message("@unerased:local_hs")
 
         filtered = self.get_success(
-            filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
+            filter_events_for_server(
+                self._storage_controllers, "remote_hs", [outlier, evt]
+            )
         )
         self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
         self.assertEqual(filtered[0], outlier)
@@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         # ... but other servers should only be able to see the outlier (the other should
         # be redacted)
         filtered = self.get_success(
-            filter_events_for_server(self.storage, "other_server", [outlier, evt])
+            filter_events_for_server(
+                self._storage_controllers, "other_server", [outlier, evt]
+            )
         )
         self.assertEqual(filtered[0], outlier)
         self.assertEqual(filtered[1].event_id, evt.event_id)
@@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
         # ... and the filtering happens.
         filtered = self.get_success(
-            filter_events_for_server(self.storage, "test_server", events_to_filter)
+            filter_events_for_server(
+                self._storage_controllers, "test_server", events_to_filter
+            )
         )
 
         for i in range(0, len(events_to_filter)):
@@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         event, context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
-        self.get_success(self.storage.persistence.persist_event(event, context))
+        self.get_success(
+            self._storage_controllers.persistence.persist_event(event, context)
+        )
         return event
 
     def _inject_room_member(
@@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.storage.persistence.persist_event(event, context))
+        self.get_success(
+            self._storage_controllers.persistence.persist_event(event, context)
+        )
         return event
 
     def _inject_message(
@@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.storage.persistence.persist_event(event, context))
+        self.get_success(
+            self._storage_controllers.persistence.persist_event(event, context)
+        )
         return event
 
     def _inject_outlier(self) -> EventBase:
@@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
         event.internal_metadata.outlier = True
         self.get_success(
-            self.storage.persistence.persist_event(
-                event, EventContext.for_outlier(self.storage)
+            self._storage_controllers.persistence.persist_event(
+                event, EventContext.for_outlier(self._storage_controllers)
             )
         )
         return event
@@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
         self.assertEqual(
             self.get_success(
                 filter_events_for_client(
-                    self.hs.get_storage(), "@user:test", [invite_event, reject_event]
+                    self.hs.get_storage_controllers(),
+                    "@user:test",
+                    [invite_event, reject_event],
                 )
             ),
             [invite_event, reject_event],
@@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
         self.assertEqual(
             self.get_success(
                 filter_events_for_client(
-                    self.hs.get_storage(), "@other:test", [invite_event, reject_event]
+                    self.hs.get_storage_controllers(),
+                    "@other:test",
+                    [invite_event, reject_event],
                 )
             ),
             [],
diff --git a/tests/utils.py b/tests/utils.py
index d4ba3a9b99..3059c453d5 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -264,7 +264,7 @@ class MockClock:
 async def create_room(hs, room_id: str, creator_id: str):
     """Creates and persist a creation event for the given room"""
 
-    persistence_store = hs.get_storage().persistence
+    persistence_store = hs.get_storage_controllers().persistence
     store = hs.get_datastores().main
     event_builder_factory = hs.get_event_builder_factory()
     event_creation_handler = hs.get_event_creation_handler()