diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 7a9b01ef9d..f338af6c36 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
super(FilterEventsForServerTestCase, self).setUp()
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
events_to_filter.append(evt)
filtered = self.get_success(
- filter_events_for_server(self.storage, "test_server", events_to_filter)
+ filter_events_for_server(
+ self._storage_controllers, "test_server", events_to_filter
+ )
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
outlier = self._inject_outlier()
self.assertEqual(
self.get_success(
- filter_events_for_server(self.storage, "remote_hs", [outlier])
+ filter_events_for_server(
+ self._storage_controllers, "remote_hs", [outlier]
+ )
),
[outlier],
)
@@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
evt = self._inject_message("@unerased:local_hs")
filtered = self.get_success(
- filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
+ filter_events_for_server(
+ self._storage_controllers, "remote_hs", [outlier, evt]
+ )
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
self.assertEqual(filtered[0], outlier)
@@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... but other servers should only be able to see the outlier (the other should
# be redacted)
filtered = self.get_success(
- filter_events_for_server(self.storage, "other_server", [outlier, evt])
+ filter_events_for_server(
+ self._storage_controllers, "other_server", [outlier, evt]
+ )
)
self.assertEqual(filtered[0], outlier)
self.assertEqual(filtered[1].event_id, evt.event_id)
@@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens.
filtered = self.get_success(
- filter_events_for_server(self.storage, "test_server", events_to_filter)
+ filter_events_for_server(
+ self._storage_controllers, "test_server", events_to_filter
+ )
)
for i in range(0, len(events_to_filter)):
@@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_room_member(
@@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_message(
@@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_outlier(self) -> EventBase:
@@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True
self.get_success(
- self.storage.persistence.persist_event(
- event, EventContext.for_outlier(self.storage)
+ self._storage_controllers.persistence.persist_event(
+ event, EventContext.for_outlier(self._storage_controllers)
)
)
return event
@@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
- self.hs.get_storage(), "@user:test", [invite_event, reject_event]
+ self.hs.get_storage_controllers(),
+ "@user:test",
+ [invite_event, reject_event],
)
),
[invite_event, reject_event],
@@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
- self.hs.get_storage(), "@other:test", [invite_event, reject_event]
+ self.hs.get_storage_controllers(),
+ "@other:test",
+ [invite_event, reject_event],
)
),
[],
|