diff --git a/tests/test_state.py b/tests/test_state.py
index c0f2d1152d..429a18cbf7 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -204,7 +204,8 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- self.assertEqual(2, len(context_store["D"].prev_state_ids))
+ prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ self.assertEqual(2, len(prev_state_ids))
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
@@ -255,9 +256,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
+ prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+
self.assertSetEqual(
{"START", "A", "C"},
- {e_id for e_id in context_store["D"].prev_state_ids.values()}
+ {e_id for e_id in prev_state_ids.values()}
)
@defer.inlineCallbacks
@@ -318,9 +321,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
+ prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+
self.assertSetEqual(
{"START", "A", "B", "C"},
- {e for e in context_store["E"].prev_state_ids.values()}
+ {e for e in prev_state_ids.values()}
)
@defer.inlineCallbacks
@@ -398,9 +403,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
+ prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"},
- {e for e in context_store["D"].prev_state_ids.values()}
+ {e for e in prev_state_ids.values()}
)
def _add_depths(self, nodes, edges):
@@ -429,8 +436,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
+ current_state_ids = yield context.get_current_state_ids(self.store)
+
self.assertEqual(
- set(e.event_id for e in old_state), set(context.current_state_ids.values())
+ set(e.event_id for e in old_state), set(current_state_ids.values())
)
self.assertIsNotNone(context.state_group)
@@ -449,8 +458,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
self.assertEqual(
- set(e.event_id for e in old_state), set(context.prev_state_ids.values())
+ set(e.event_id for e in old_state), set(prev_state_ids.values())
)
@defer.inlineCallbacks
@@ -475,9 +486,11 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
+ current_state_ids = yield context.get_current_state_ids(self.store)
+
self.assertEqual(
set([e.event_id for e in old_state]),
- set(context.current_state_ids.values())
+ set(current_state_ids.values())
)
self.assertEqual(group_name, context.state_group)
@@ -504,9 +517,11 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
self.assertEqual(
set([e.event_id for e in old_state]),
- set(context.prev_state_ids.values())
+ set(prev_state_ids.values())
)
self.assertIsNotNone(context.state_group)
@@ -545,7 +560,9 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
- self.assertEqual(len(context.current_state_ids), 6)
+ current_state_ids = yield context.get_current_state_ids(self.store)
+
+ self.assertEqual(len(current_state_ids), 6)
self.assertIsNotNone(context.state_group)
@@ -585,7 +602,9 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
- self.assertEqual(len(context.current_state_ids), 6)
+ current_state_ids = yield context.get_current_state_ids(self.store)
+
+ self.assertEqual(len(current_state_ids), 6)
self.assertIsNotNone(context.state_group)
@@ -642,8 +661,10 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
+ current_state_ids = yield context.get_current_state_ids(self.store)
+
self.assertEqual(
- old_state_2[3].event_id, context.current_state_ids[("test1", "1")]
+ old_state_2[3].event_id, current_state_ids[("test1", "1")]
)
# Reverse the depth to make sure we are actually using the depths
@@ -670,8 +691,10 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
+ current_state_ids = yield context.get_current_state_ids(self.store)
+
self.assertEqual(
- old_state_1[3].event_id, context.current_state_ids[("test1", "1")]
+ old_state_1[3].event_id, current_state_ids[("test1", "1")]
)
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
|