summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/federation/test_federation_server.py6
-rw-r--r--tests/handlers/test_directory.py3
-rw-r--r--tests/storage/test_events.py17
-rw-r--r--tests/storage/test_purge.py5
-rw-r--r--tests/storage/test_room.py12
5 files changed, 27 insertions, 16 deletions
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index b19365b81a..413b3c9426 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
         super().prepare(reactor, clock, hs)
 
+        self._storage_controllers = hs.get_storage_controllers()
+
         # create the room
         creator_user_id = self.register_user("kermit", "test")
         tok = self.login("kermit", "test")
@@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
 
         # the room should show that the new user is a member
         r = self.get_success(
-            self.hs.get_state_handler().get_current_state(self._room_id)
+            self._storage_controllers.state.get_current_state(self._room_id)
         )
         self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
 
@@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
 
         # the room should show that the new user is a member
         r = self.get_success(
-            self.hs.get_state_handler().get_current_state(self._room_id)
+            self._storage_controllers.state.get_current_state(self._room_id)
         )
         self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
 
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 11ad44223d..53d49ca896 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
+        self._storage_controllers = hs.get_storage_controllers()
 
         # Create user
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
     def _get_canonical_alias(self):
         """Get the canonical alias state of the room."""
         return self.get_success(
-            self.state_handler.get_current_state(
+            self._storage_controllers.state.get_current_state_event(
                 self.room_id, EventTypes.CanonicalAlias, ""
             )
         )
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index a76718e8f9..2ff88e64a5 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -32,6 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
         self.state = self.hs.get_state_handler()
         self._persistence = self.hs.get_storage_controllers().persistence
+        self._state_storage_controller = self.hs.get_storage_controllers().state
         self.store = self.hs.get_datastores().main
 
         self.register_user("user", "pass")
@@ -104,7 +105,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -137,7 +138,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
         # setting. The state resolution across the old and new event will then
         # include it, and so the resolved state won't match the new state.
         state_before_gap = dict(
-            self.get_success(self.state.get_current_state_ids(self.room_id))
+            self.get_success(
+                self._state_storage_controller.get_current_state_ids(self.room_id)
+            )
         )
         state_before_gap.pop(("m.room.history_visibility", ""))
 
@@ -181,7 +184,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -213,7 +216,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -255,7 +258,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -299,7 +302,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
@@ -335,7 +338,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
         )
 
         state_before_gap = self.get_success(
-            self.state.get_current_state_ids(self.room_id)
+            self._state_storage_controller.get_current_state_ids(self.room_id)
         )
 
         self.persist_event(remote_event_2, state=state_before_gap)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 92cd0dfc05..8dfaa0559b 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -102,9 +102,10 @@ class PurgeTests(HomeserverTestCase):
         first = self.helper.send(self.room_id, body="test1")
 
         # Get the current room state.
-        state_handler = self.hs.get_state_handler()
         create_event = self.get_success(
-            state_handler.get_current_state(self.room_id, "m.room.create", "")
+            self._storage_controllers.state.get_current_state_event(
+                self.room_id, "m.room.create", ""
+            )
         )
         self.assertIsNotNone(create_event)
 
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index d497a19f63..3c79dabc9f 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
         self.store = hs.get_datastores().main
-        self._storage = hs.get_storage_controllers()
+        self._storage_controllers = hs.get_storage_controllers()
         self.event_factory = hs.get_event_factory()
 
         self.room = RoomID.from_string("!abcde:test")
@@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
 
     def inject_room_event(self, **kwargs):
         self.get_success(
-            self._storage.persistence.persist_event(
+            self._storage_controllers.persistence.persist_event(
                 self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
             )
         )
@@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
         )
 
         state = self.get_success(
-            self.store.get_current_state(room_id=self.room.to_string())
+            self._storage_controllers.state.get_current_state(
+                room_id=self.room.to_string()
+            )
         )
 
         self.assertEqual(1, len(state))
@@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
         )
 
         state = self.get_success(
-            self.store.get_current_state(room_id=self.room.to_string())
+            self._storage_controllers.state.get_current_state(
+                room_id=self.room.to_string()
+            )
         )
 
         self.assertEqual(1, len(state))