diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index defbc68c18..8ddce83b83 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
@@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def _check_serialize_deserialize(self, event, context):
serialized = self.get_success(context.serialize(event, self.store))
- d_context = EventContext.deserialize(self.storage, serialized)
+ d_context = EventContext.deserialize(self._storage_controllers, serialized)
self.assertEqual(context.state_group, d_context.state_group)
self.assertEqual(context.rejected, d_context.rejected)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index ec00900621..500c9ccfbc 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
- self.state_storage = hs.get_storage().state
+ self.state_storage_controller = hs.get_storage_controllers().state
self._event_auth_handler = hs.get_event_auth_handler()
return hs
@@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# mapping from (type, state_key) -> state_event_id
assert most_recent_prev_event_id is not None
prev_state_map = self.get_success(
- self.state_storage.get_state_ids_for_event(most_recent_prev_event_id)
+ self.state_storage_controller.get_state_ids_for_event(
+ most_recent_prev_event_id
+ )
)
# List of state event ID's
prev_state_ids = list(prev_state_map.values())
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index e64b28f28b..1d5b2492c0 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
) -> None:
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main
- state_storage = self.hs.get_storage().state
+ state_storage_controller = self.hs.get_storage_controllers().state
# create the room
user_id = self.register_user("kermit", "test")
@@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True
- persistence = self.hs.get_storage().persistence
+ persistence = self.hs.get_storage_controllers().persistence
self.get_success(
persistence.persist_event(
- prev_event, EventContext.for_outlier(self.hs.get_storage())
+ prev_event,
+ EventContext.for_outlier(self.hs.get_storage_controllers()),
)
)
else:
@@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# check that the state at that event is as expected
state = self.get_success(
- state_storage.get_state_ids_for_event(pulled_event.event_id)
+ state_storage_controller.get_state_ids_for_event(pulled_event.event_id)
)
expected_state = {
(e.type, e.state_key): e.event_id for e in state_at_prev_event
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index f4f7ab4845..44da96c792 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.handler = self.hs.get_event_creation_handler()
- self.persist_event_storage = self.hs.get_storage().persistence
+ self._persist_event_storage_controller = (
+ self.hs.get_storage_controllers().persistence
+ )
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
@@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self._persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
return memberEvent, memberEventContext
@@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
- self.persist_event_storage.persist_event(event3, context)
+ self._persist_event_storage_controller.persist_event(event3, context)
)
# Assert that the returned values match those from the initial event
@@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
- self.persist_event_storage.persist_events([(event3, context)])
+ self._persist_event_storage_controller.persist_events([(event3, context)])
)
ret_event4 = events[0]
@@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event2.event_id)
events, _ = self.get_success(
- self.persist_event_storage.persist_events(
+ self._persist_event_storage_controller.persist_events(
[(event1, context1), (event2, context2)]
)
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 4d658d29ca..a68c2ffd45 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.hs.get_storage().persistence.persist_event(event, context)
+ self.hs.get_storage_controllers().persistence.persist_event(event, context)
)
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 85be79d19d..c5705256e6 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 297a9e77f8..6d3d4afe52 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
msg, msgctx = self.build_event()
self.get_success(
- self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
+ self._storage_controllers.persistence.persist_events(
+ [(j2, j2ctx), (msg, msgctx)]
+ )
)
self.replicate()
@@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
- self.storage.persistence.persist_events(
+ self._storage_controllers.persistence.persist_events(
[(event, context)], backfilled=True
)
)
else:
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index 5bbbd5fbcb..19f57115a1 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
def prepare(self, reactor, clock, homeserver):
super().prepare(reactor, clock, homeserver)
self.room_creator = homeserver.get_room_creation_handler()
- self.persist_event_storage = self.hs.get_storage().persistence
+ self.persist_event_storage_controller = (
+ self.hs.get_storage_controllers().persistence
+ )
# Create a test user
self.ourUser = UserID.from_string(OUR_USER_ID)
@@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self.persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
# Join the second user to the second room
@@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self.persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
def test_return_empty_with_no_data(self):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 0cdf1dec40..0d44102237 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
- storage = self.hs.get_storage()
+ storage_controllers = self.hs.get_storage_controllers()
# Create two rooms, one with a local user only and one with both a local
# and remote user.
@@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
event_creation_handler.create_new_client_event(builder)
)
- self.get_success(storage.persistence.persist_event(event, context))
+ self.get_success(storage_controllers.persistence.persist_event(event, context))
# Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 2cd7a9e6c5..ac9c113354 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
We do this by setting a very long time between purge jobs.
"""
store = self.hs.get_datastores().main
- storage = self.hs.get_storage()
+ storage_controllers = self.hs.get_storage_controllers()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
# Send a first event, which should be filtered out at the end of the test.
@@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success(
- filter_events_for_client(storage, self.user_id, events)
+ filter_events_for_client(storage_controllers, self.user_id, events)
)
# We should only get one event back.
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
index 41a1bf6d89..1b7ee08ab2 100644
--- a/tests/rest/client/test_room_batch.py
+++ b/tests/rest/client/test_room_batch.py
@@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.clock = clock
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.virtual_user_id, _ = self.register_appservice_user(
"as_user_potato", self.appservice.token
@@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
# Fetch the state_groups
state_group_map = self.get_success(
- self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
+ self._storage_controllers.state.get_state_groups_ids(
+ room_id, historical_event_ids
+ )
)
# We expect all of the historical events to be using the same state_group
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c7661e7186..a0ce077a99 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(
- txn, [(e, EventContext(self.hs.get_storage())) for e in events]
+ txn,
+ [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
)
# Actually call the function that calculates the auth chain stuff.
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index aaa3189b16..a76718e8f9 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
- self.persistence = self.hs.get_storage().persistence
+ self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
self.register_user("user", "pass")
@@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
context = self.get_success(
self.state.compute_event_context(event, state_ids_before_event=state)
)
- self.get_success(self.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
def assert_extremities(self, expected_extremities):
"""Assert the current extremities for the room"""
@@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
)
- self.get_success(self.persistence.persist_event(remote_event_2, context))
+ self.get_success(self._persistence.persist_event(remote_event_2, context))
# Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
- self.persistence = self.hs.get_storage().persistence
+ self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self):
@@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
- self.get_success(self.persistence.persist_event(remote_event_1, context))
+ self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_rooms_for_user` to add the remote user to the cache
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
@@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
- self.get_success(self.persistence.persist_event(remote_event_1, context))
+ self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_users_in_room` to add the remote user to the cache
users = self.get_success(self.store.get_users_in_room(room_id))
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 08cc60237e..92cd0dfc05 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
def test_purge_history(self):
"""
@@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
self.get_success(
- self.storage.purge_events.purge_history(self.room_id, token_str, True)
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, token_str, True
+ )
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
@@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
f = self.get_failure(
- self.storage.purge_events.purge_history(self.room_id, event, True),
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, event, True
+ ),
SynapseError,
)
self.assertIn("greater than forward", f.value.args[0])
@@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase):
self.assertIsNotNone(create_event)
# Purge everything before this topological token
- self.get_success(self.storage.purge_events.purge_room(self.room_id))
+ self.get_success(
+ self._storage_controllers.purge_events.purge_room(self.room_id)
+ )
# The events aren't found.
self.store._invalidate_get_event_cache(create_event.event_id)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d8d17ef379..6c4e63b77c 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage = hs.get_storage_controllers()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(self.storage.persistence.persist_event(event_1, context_1))
+ self.get_success(self._storage.persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
@@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
- self.get_success(self.storage.persistence.persist_event(event_2, context_2))
+ self.get_success(self._storage.persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.storage.persistence.persist_event(redaction_event, context)
+ self._storage.persistence.persist_event(redaction_event, context)
)
# Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 5b011e18cd..d497a19f63 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage = hs.get_storage_controllers()
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")
@@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
def inject_room_event(self, **kwargs):
self.get_success(
- self.storage.persistence.persist_event(
+ self._storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 8dfc1e1db9..e747c6b50e 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
prev_event = self.get_success(store.get_event(prev_event_ids[0]))
prev_state_map = self.get_success(
- self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0])
+ self.hs.get_storage_controllers().state.get_state_ids_for_event(
+ prev_event_ids[0]
+ )
)
event_dict = {
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index f88f1c55fc..8043bdbde2 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
diff --git a/tests/test_state.py b/tests/test_state.py
index 84694d368d..95f81bebae 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -179,12 +179,12 @@ class Graph:
class StateTestCase(unittest.TestCase):
def setUp(self):
self.dummy_store = _DummyStore()
- storage = Mock(main=self.dummy_store, state=self.dummy_store)
+ storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
hs = Mock(
spec_set=[
"config",
"get_datastores",
- "get_storage",
+ "get_storage_controllers",
"get_auth",
"get_state_handler",
"get_clock",
@@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
- hs.get_storage.return_value = storage
+ hs.get_storage_controllers.return_value = storage_controllers
self.state = StateHandler(hs)
self.event_id = 0
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index c654e36ee4..8027c7a856 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -70,7 +70,7 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- persistence = hs.get_storage().persistence
+ persistence = hs.get_storage_controllers().persistence
assert persistence is not None
await persistence.persist_event(event, context)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 7a9b01ef9d..f338af6c36 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
super(FilterEventsForServerTestCase, self).setUp()
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
events_to_filter.append(evt)
filtered = self.get_success(
- filter_events_for_server(self.storage, "test_server", events_to_filter)
+ filter_events_for_server(
+ self._storage_controllers, "test_server", events_to_filter
+ )
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
outlier = self._inject_outlier()
self.assertEqual(
self.get_success(
- filter_events_for_server(self.storage, "remote_hs", [outlier])
+ filter_events_for_server(
+ self._storage_controllers, "remote_hs", [outlier]
+ )
),
[outlier],
)
@@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
evt = self._inject_message("@unerased:local_hs")
filtered = self.get_success(
- filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
+ filter_events_for_server(
+ self._storage_controllers, "remote_hs", [outlier, evt]
+ )
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
self.assertEqual(filtered[0], outlier)
@@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... but other servers should only be able to see the outlier (the other should
# be redacted)
filtered = self.get_success(
- filter_events_for_server(self.storage, "other_server", [outlier, evt])
+ filter_events_for_server(
+ self._storage_controllers, "other_server", [outlier, evt]
+ )
)
self.assertEqual(filtered[0], outlier)
self.assertEqual(filtered[1].event_id, evt.event_id)
@@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens.
filtered = self.get_success(
- filter_events_for_server(self.storage, "test_server", events_to_filter)
+ filter_events_for_server(
+ self._storage_controllers, "test_server", events_to_filter
+ )
)
for i in range(0, len(events_to_filter)):
@@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_room_member(
@@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_message(
@@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_outlier(self) -> EventBase:
@@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True
self.get_success(
- self.storage.persistence.persist_event(
- event, EventContext.for_outlier(self.storage)
+ self._storage_controllers.persistence.persist_event(
+ event, EventContext.for_outlier(self._storage_controllers)
)
)
return event
@@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
- self.hs.get_storage(), "@user:test", [invite_event, reject_event]
+ self.hs.get_storage_controllers(),
+ "@user:test",
+ [invite_event, reject_event],
)
),
[invite_event, reject_event],
@@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
- self.hs.get_storage(), "@other:test", [invite_event, reject_event]
+ self.hs.get_storage_controllers(),
+ "@other:test",
+ [invite_event, reject_event],
)
),
[],
diff --git a/tests/utils.py b/tests/utils.py
index d4ba3a9b99..3059c453d5 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -264,7 +264,7 @@ class MockClock:
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room"""
- persistence_store = hs.get_storage().persistence
+ persistence_store = hs.get_storage_controllers().persistence
store = hs.get_datastores().main
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
|