diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/events/test_snapshot.py | 4 | ||||
-rw-r--r-- | tests/handlers/test_federation.py | 6 | ||||
-rw-r--r-- | tests/handlers/test_federation_event.py | 9 | ||||
-rw-r--r-- | tests/handlers/test_message.py | 14 | ||||
-rw-r--r-- | tests/handlers/test_user_directory.py | 2 | ||||
-rw-r--r-- | tests/replication/slave/storage/_base.py | 2 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_events.py | 10 | ||||
-rw-r--r-- | tests/replication/slave/storage/test_receipts.py | 12 | ||||
-rw-r--r-- | tests/rest/admin/test_user.py | 4 | ||||
-rw-r--r-- | tests/rest/client/test_retention.py | 4 | ||||
-rw-r--r-- | tests/rest/client/test_room_batch.py | 6 | ||||
-rw-r--r-- | tests/storage/test_event_chain.py | 3 | ||||
-rw-r--r-- | tests/storage/test_events.py | 12 | ||||
-rw-r--r-- | tests/storage/test_purge.py | 14 | ||||
-rw-r--r-- | tests/storage/test_redaction.py | 14 | ||||
-rw-r--r-- | tests/storage/test_room.py | 4 | ||||
-rw-r--r-- | tests/storage/test_room_search.py | 4 | ||||
-rw-r--r-- | tests/storage/test_state.py | 2 | ||||
-rw-r--r-- | tests/test_state.py | 6 | ||||
-rw-r--r-- | tests/test_utils/event_injection.py | 2 | ||||
-rw-r--r-- | tests/test_visibility.py | 46 | ||||
-rw-r--r-- | tests/utils.py | 2 |
22 files changed, 115 insertions, 67 deletions
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index defbc68c18..8ddce83b83 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.user_id = self.register_user("u1", "pass") self.user_tok = self.login("u1", "pass") @@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase): def _check_serialize_deserialize(self, event, context): serialized = self.get_success(context.serialize(event, self.store)) - d_context = EventContext.deserialize(self.storage, serialized) + d_context = EventContext.deserialize(self._storage_controllers, serialized) self.assertEqual(context.state_group, d_context.state_group) self.assertEqual(context.rejected, d_context.rejected) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index ec00900621..500c9ccfbc 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main - self.state_storage = hs.get_storage().state + self.state_storage_controller = hs.get_storage_controllers().state self._event_auth_handler = hs.get_event_auth_handler() return hs @@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # mapping from (type, state_key) -> state_event_id assert most_recent_prev_event_id is not None prev_state_map = self.get_success( - self.state_storage.get_state_ids_for_event(most_recent_prev_event_id) + self.state_storage_controller.get_state_ids_for_event( + most_recent_prev_event_id + ) ) # List of state event ID's prev_state_ids = list(prev_state_map.values()) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index e64b28f28b..1d5b2492c0 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) -> None: OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" main_store = self.hs.get_datastores().main - state_storage = self.hs.get_storage().state + state_storage_controller = self.hs.get_storage_controllers().state # create the room user_id = self.register_user("kermit", "test") @@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) if prev_exists_as_outlier: prev_event.internal_metadata.outlier = True - persistence = self.hs.get_storage().persistence + persistence = self.hs.get_storage_controllers().persistence self.get_success( persistence.persist_event( - prev_event, EventContext.for_outlier(self.hs.get_storage()) + prev_event, + EventContext.for_outlier(self.hs.get_storage_controllers()), ) ) else: @@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # check that the state at that event is as expected state = self.get_success( - state_storage.get_state_ids_for_event(pulled_event.event_id) + state_storage_controller.get_state_ids_for_event(pulled_event.event_id) ) expected_state = { (e.type, e.state_key): e.event_id for e in state_at_prev_event diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index f4f7ab4845..44da96c792 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.handler = self.hs.get_event_creation_handler() - self.persist_event_storage = self.hs.get_storage().persistence + self._persist_event_storage_controller = ( + self.hs.get_storage_controllers().persistence + ) self.user_id = self.register_user("tester", "foobar") self.access_token = self.login("tester", "foobar") @@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self._persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) return memberEvent, memberEventContext @@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event3.event_id) ret_event3, event_pos3, _ = self.get_success( - self.persist_event_storage.persist_event(event3, context) + self._persist_event_storage_controller.persist_event(event3, context) ) # Assert that the returned values match those from the initial event @@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( - self.persist_event_storage.persist_events([(event3, context)]) + self._persist_event_storage_controller.persist_events([(event3, context)]) ) ret_event4 = events[0] @@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event2.event_id) events, _ = self.get_success( - self.persist_event_storage.persist_events( + self._persist_event_storage_controller.persist_events( [(event1, context1), (event2, context2)] ) ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 4d658d29ca..a68c2ffd45 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.hs.get_storage().persistence.persist_event(event, context) + self.hs.get_storage_controllers().persistence.persist_event(event, context) ) def test_local_user_leaving_room_remains_in_user_directory(self) -> None: diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 85be79d19d..c5705256e6 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): self.master_store = hs.get_datastores().main self.slaved_store = self.worker_hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() def replicate(self): """Tell the master side of replication that something has happened, and then diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 297a9e77f8..6d3d4afe52 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) msg, msgctx = self.build_event() self.get_success( - self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)]) + self._storage_controllers.persistence.persist_events( + [(j2, j2ctx), (msg, msgctx)] + ) ) self.replicate() @@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if backfill: self.get_success( - self.storage.persistence.persist_events( + self._storage_controllers.persistence.persist_events( [(event, context)], backfilled=True ) ) else: - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index 5bbbd5fbcb..19f57115a1 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): def prepare(self, reactor, clock, homeserver): super().prepare(reactor, clock, homeserver) self.room_creator = homeserver.get_room_creation_handler() - self.persist_event_storage = self.hs.get_storage().persistence + self.persist_event_storage_controller = ( + self.hs.get_storage_controllers().persistence + ) # Create a test user self.ourUser = UserID.from_string(OUR_USER_ID) @@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self.persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) # Join the second user to the second room @@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): ) ) self.get_success( - self.persist_event_storage.persist_event(memberEvent, memberEventContext) + self.persist_event_storage_controller.persist_event( + memberEvent, memberEventContext + ) ) def test_return_empty_with_no_data(self): diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 0cdf1dec40..0d44102237 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") event_builder_factory = self.hs.get_event_builder_factory() event_creation_handler = self.hs.get_event_creation_handler() - storage = self.hs.get_storage() + storage_controllers = self.hs.get_storage_controllers() # Create two rooms, one with a local user only and one with both a local # and remote user. @@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): event_creation_handler.create_new_client_event(builder) ) - self.get_success(storage.persistence.persist_event(event, context)) + self.get_success(storage_controllers.persistence.persist_event(event, context)) # Now get rooms url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 2cd7a9e6c5..ac9c113354 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): We do this by setting a very long time between purge jobs. """ store = self.hs.get_datastores().main - storage = self.hs.get_storage() + storage_controllers = self.hs.get_storage_controllers() room_id = self.helper.create_room_as(self.user_id, tok=self.token) # Send a first event, which should be filtered out at the end of the test. @@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): ) self.assertEqual(2, len(events), "events retrieved from database") filtered_events = self.get_success( - filter_events_for_client(storage, self.user_id, events) + filter_events_for_client(storage_controllers, self.user_id, events) ) # We should only get one event back. diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index 41a1bf6d89..1b7ee08ab2 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.clock = clock - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.virtual_user_id, _ = self.register_appservice_user( "as_user_potato", self.appservice.token @@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): # Fetch the state_groups state_group_map = self.get_success( - self.storage.state.get_state_groups_ids(room_id, historical_event_ids) + self._storage_controllers.state.get_state_groups_ids( + room_id, historical_event_ids + ) ) # We expect all of the historical events to be using the same state_group diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index c7661e7186..a0ce077a99 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase): # We need to persist the events to the events and state_events # tables. persist_events_store._store_event_txn( - txn, [(e, EventContext(self.hs.get_storage())) for e in events] + txn, + [(e, EventContext(self.hs.get_storage_controllers())) for e in events], ) # Actually call the function that calculates the auth chain stuff. diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index aaa3189b16..a76718e8f9 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage().persistence + self._persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main self.register_user("user", "pass") @@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase): context = self.get_success( self.state.compute_event_context(event, state_ids_before_event=state) ) - self.get_success(self.persistence.persist_event(event, context)) + self.get_success(self._persistence.persist_event(event, context)) def assert_extremities(self, expected_extremities): """Assert the current extremities for the room""" @@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) ) - self.get_success(self.persistence.persist_event(remote_event_2, context)) + self.get_success(self._persistence.persist_event(remote_event_2, context)) # Check that we haven't dropped the old extremity. self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) @@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() - self.persistence = self.hs.get_storage().persistence + self._persistence = self.hs.get_storage_controllers().persistence self.store = self.hs.get_datastores().main def test_remote_user_rooms_cache_invalidated(self): @@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): ) context = self.get_success(self.state.compute_event_context(remote_event_1)) - self.get_success(self.persistence.persist_event(remote_event_1, context)) + self.get_success(self._persistence.persist_event(remote_event_1, context)) # Call `get_rooms_for_user` to add the remote user to the cache rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) @@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): ) context = self.get_success(self.state.compute_event_context(remote_event_1)) - self.get_success(self.persistence.persist_event(remote_event_1, context)) + self.get_success(self._persistence.persist_event(remote_event_1, context)) # Call `get_users_in_room` to add the remote user to the cache users = self.get_success(self.store.get_users_in_room(room_id)) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 08cc60237e..92cd0dfc05 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id) self.store = hs.get_datastores().main - self.storage = self.hs.get_storage() + self._storage_controllers = self.hs.get_storage_controllers() def test_purge_history(self): """ @@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase): # Purge everything before this topological token self.get_success( - self.storage.purge_events.purge_history(self.room_id, token_str, True) + self._storage_controllers.purge_events.purge_history( + self.room_id, token_str, True + ) ) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted @@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase): # Purge everything before this topological token f = self.get_failure( - self.storage.purge_events.purge_history(self.room_id, event, True), + self._storage_controllers.purge_events.purge_history( + self.room_id, event, True + ), SynapseError, ) self.assertIn("greater than forward", f.value.args[0]) @@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase): self.assertIsNotNone(create_event) # Purge everything before this topological token - self.get_success(self.storage.purge_events.purge_room(self.room_id)) + self.get_success( + self._storage_controllers.purge_events.purge_room(self.room_id) + ) # The events aren't found. self.store._invalidate_get_event_cache(create_event.event_id) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index d8d17ef379..6c4e63b77c 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage = hs.get_storage_controllers() self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() @@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success(self._storage.persistence.persist_event(event, context)) return event @@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) - self.get_success(self.storage.persistence.persist_event(event_1, context_1)) + self.get_success(self._storage.persistence.persist_event(event_1, context_1)) event_2, context_2 = self.get_success( self.event_creation_handler.create_new_client_event( @@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) - self.get_success(self.storage.persistence.persist_event(event_2, context_2)) + self.get_success(self._storage.persistence.persist_event(event_2, context_2)) # fetch one of the redactions fetched = self.get_success(self.store.get_event(redaction_event_id1)) @@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.storage.persistence.persist_event(redaction_event, context) + self._storage.persistence.persist_event(redaction_event, context) ) # Now lets jump to the future where we have censored the redaction event diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 5b011e18cd..d497a19f63 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage = hs.get_storage_controllers() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") @@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): def inject_room_event(self, **kwargs): self.get_success( - self.storage.persistence.persist_event( + self._storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) ) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 8dfc1e1db9..e747c6b50e 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase): prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) prev_event = self.get_success(store.get_event(prev_event_ids[0])) prev_state_map = self.get_success( - self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0]) + self.hs.get_storage_controllers().state.get_state_ids_for_event( + prev_event_ids[0] + ) ) event_dict = { diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index f88f1c55fc..8043bdbde2 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self.storage = hs.get_storage_controllers() self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/test_state.py b/tests/test_state.py index 84694d368d..95f81bebae 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -179,12 +179,12 @@ class Graph: class StateTestCase(unittest.TestCase): def setUp(self): self.dummy_store = _DummyStore() - storage = Mock(main=self.dummy_store, state=self.dummy_store) + storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store) hs = Mock( spec_set=[ "config", "get_datastores", - "get_storage", + "get_storage_controllers", "get_auth", "get_state_handler", "get_clock", @@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) - hs.get_storage.return_value = storage + hs.get_storage_controllers.return_value = storage_controllers self.state = StateHandler(hs) self.event_id = 0 diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index c654e36ee4..8027c7a856 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -70,7 +70,7 @@ async def inject_event( """ event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) - persistence = hs.get_storage().persistence + persistence = hs.get_storage_controllers().persistence assert persistence is not None await persistence.persist_event(event, context) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 7a9b01ef9d..f338af6c36 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): super(FilterEventsForServerTestCase, self).setUp() self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() - self.storage = self.hs.get_storage() + self._storage_controllers = self.hs.get_storage_controllers() self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): events_to_filter.append(evt) filtered = self.get_success( - filter_events_for_server(self.storage, "test_server", events_to_filter) + filter_events_for_server( + self._storage_controllers, "test_server", events_to_filter + ) ) # the result should be 5 redacted events, and 5 unredacted events. @@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): outlier = self._inject_outlier() self.assertEqual( self.get_success( - filter_events_for_server(self.storage, "remote_hs", [outlier]) + filter_events_for_server( + self._storage_controllers, "remote_hs", [outlier] + ) ), [outlier], ) @@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): evt = self._inject_message("@unerased:local_hs") filtered = self.get_success( - filter_events_for_server(self.storage, "remote_hs", [outlier, evt]) + filter_events_for_server( + self._storage_controllers, "remote_hs", [outlier, evt] + ) ) self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") self.assertEqual(filtered[0], outlier) @@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # ... but other servers should only be able to see the outlier (the other should # be redacted) filtered = self.get_success( - filter_events_for_server(self.storage, "other_server", [outlier, evt]) + filter_events_for_server( + self._storage_controllers, "other_server", [outlier, evt] + ) ) self.assertEqual(filtered[0], outlier) self.assertEqual(filtered[1].event_id, evt.event_id) @@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # ... and the filtering happens. filtered = self.get_success( - filter_events_for_server(self.storage, "test_server", events_to_filter) + filter_events_for_server( + self._storage_controllers, "test_server", events_to_filter + ) ) for i in range(0, len(events_to_filter)): @@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_room_member( @@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_message( @@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.event_creation_handler.create_new_client_event(builder) ) - self.get_success(self.storage.persistence.persist_event(event, context)) + self.get_success( + self._storage_controllers.persistence.persist_event(event, context) + ) return event def _inject_outlier(self) -> EventBase: @@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event.internal_metadata.outlier = True self.get_success( - self.storage.persistence.persist_event( - event, EventContext.for_outlier(self.storage) + self._storage_controllers.persistence.persist_event( + event, EventContext.for_outlier(self._storage_controllers) ) ) return event @@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual( self.get_success( filter_events_for_client( - self.hs.get_storage(), "@user:test", [invite_event, reject_event] + self.hs.get_storage_controllers(), + "@user:test", + [invite_event, reject_event], ) ), [invite_event, reject_event], @@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual( self.get_success( filter_events_for_client( - self.hs.get_storage(), "@other:test", [invite_event, reject_event] + self.hs.get_storage_controllers(), + "@other:test", + [invite_event, reject_event], ) ), [], diff --git a/tests/utils.py b/tests/utils.py index d4ba3a9b99..3059c453d5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -264,7 +264,7 @@ class MockClock: async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room""" - persistence_store = hs.get_storage().persistence + persistence_store = hs.get_storage_controllers().persistence store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() |