diff options
Diffstat (limited to 'tests/test_state.py')
-rw-r--r-- | tests/test_state.py | 38 |
1 files changed, 15 insertions, 23 deletions
diff --git a/tests/test_state.py b/tests/test_state.py index de2d35145a..6454f994e3 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -86,17 +86,8 @@ class StateGroupStore(object): state_events = dict(context.current_state_ids) - if event.is_state(): - state_events[(event.type, event.state_key)] = event.event_id - - state_group = context.state_group - if not state_group: - state_group = self._next_group - self._next_group += 1 - - self._group_to_state[state_group] = state_events - - self._event_to_state_group[event.event_id] = state_group + self._group_to_state[context.state_group] = state_events + self._event_to_state_group[event.event_id] = context.state_group def get_events(self, event_ids, **kwargs): return { @@ -151,6 +142,7 @@ class StateTestCase(unittest.TestCase): "get_state_groups_ids", "add_event_hashes", "get_events", + "get_next_state_group", ] ) hs = Mock(spec_set=[ @@ -161,6 +153,8 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) + self.store.get_next_state_group.side_effect = Mock + self.state = StateHandler(hs) self.event_id = 0 @@ -209,7 +203,7 @@ class StateTestCase(unittest.TestCase): store.store_state_groups(event, context) context_store[event.event_id] = context - self.assertEqual(2, len(context_store["D"].current_state_ids)) + self.assertEqual(2, len(context_store["D"].prev_state_ids)) @defer.inlineCallbacks def test_branch_basic_conflict(self): @@ -265,7 +259,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "C"}, - {e_id for e_id in context_store["D"].current_state_ids.values()} + {e_id for e_id in context_store["D"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -331,7 +325,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "B", "C"}, - {e for e in context_store["E"].current_state_ids.values()} + {e for e in context_store["E"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -414,7 +408,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, - {e for e in context_store["D"].current_state_ids.values()} + {e for e in context_store["D"].prev_state_ids.values()} ) def _add_depths(self, nodes, edges): @@ -447,7 +441,7 @@ class StateTestCase(unittest.TestCase): set(e.event_id for e in old_state), set(context.current_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_annotate_with_old_state(self): @@ -464,11 +458,9 @@ class StateTestCase(unittest.TestCase): ) self.assertEqual( - set(e.event_id for e in old_state), set(context.current_state_ids.values()) + set(e.event_id for e in old_state), set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) - @defer.inlineCallbacks def test_trivial_annotate_message(self): event = create_event(type="test_message", name="event") @@ -514,10 +506,10 @@ class StateTestCase(unittest.TestCase): self.assertEqual( set([e.event_id for e in old_state]), - set(context.current_state_ids.values()) + set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_message_conflict(self): @@ -550,7 +542,7 @@ class StateTestCase(unittest.TestCase): self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_state_conflict(self): @@ -583,7 +575,7 @@ class StateTestCase(unittest.TestCase): self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_standard_depth_conflict(self): |