summary refs log tree commit diff
path: root/tests/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_state.py')
-rw-r--r--tests/test_state.py20
1 files changed, 15 insertions, 5 deletions
diff --git a/tests/test_state.py b/tests/test_state.py
index c6baea3d76..95f81bebae 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -179,12 +179,12 @@ class Graph:
 class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.dummy_store = _DummyStore()
-        storage = Mock(main=self.dummy_store, state=self.dummy_store)
+        storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
         hs = Mock(
             spec_set=[
                 "config",
                 "get_datastores",
-                "get_storage",
+                "get_storage_controllers",
                 "get_auth",
                 "get_state_handler",
                 "get_clock",
@@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase):
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
         hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
-        hs.get_storage.return_value = storage
+        hs.get_storage_controllers.return_value = storage_controllers
 
         self.state = StateHandler(hs)
         self.event_id = 0
@@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
         ]
 
         context = yield defer.ensureDeferred(
-            self.state.compute_event_context(event, old_state=old_state)
+            self.state.compute_event_context(
+                event,
+                state_ids_before_event={
+                    (e.type, e.state_key): e.event_id for e in old_state
+                },
+            )
         )
 
         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
         ]
 
         context = yield defer.ensureDeferred(
-            self.state.compute_event_context(event, old_state=old_state)
+            self.state.compute_event_context(
+                event,
+                state_ids_before_event={
+                    (e.type, e.state_key): e.event_id for e in old_state
+                },
+            )
         )
 
         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())