diff options
Diffstat (limited to 'tests/test_state.py')
-rw-r--r-- | tests/test_state.py | 138 |
1 files changed, 75 insertions, 63 deletions
diff --git a/tests/test_state.py b/tests/test_state.py index 1a11bbcee0..6454f994e3 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -67,9 +67,11 @@ class StateGroupStore(object): self._event_to_state_group = {} self._group_to_state = {} + self._event_id_to_event = {} + self._next_group = 1 - def get_state_groups(self, room_id, event_ids): + def get_state_groups_ids(self, room_id, event_ids): groups = {} for event_id in event_ids: group = self._event_to_state_group.get(event_id) @@ -79,22 +81,23 @@ class StateGroupStore(object): return defer.succeed(groups) def store_state_groups(self, event, context): - if context.current_state is None: + if context.current_state_ids is None: return - state_events = context.current_state - - if event.is_state(): - state_events[(event.type, event.state_key)] = event + state_events = dict(context.current_state_ids) - state_group = context.state_group - if not state_group: - state_group = self._next_group - self._next_group += 1 + self._group_to_state[context.state_group] = state_events + self._event_to_state_group[event.event_id] = context.state_group - self._group_to_state[state_group] = state_events.values() + def get_events(self, event_ids, **kwargs): + return { + e_id: self._event_id_to_event[e_id] for e_id in event_ids + if e_id in self._event_id_to_event + } - self._event_to_state_group[event.event_id] = state_group + def register_events(self, events): + for e in events: + self._event_id_to_event[e.event_id] = e class DictObj(dict): @@ -136,8 +139,10 @@ class StateTestCase(unittest.TestCase): def setUp(self): self.store = Mock( spec_set=[ - "get_state_groups", + "get_state_groups_ids", "add_event_hashes", + "get_events", + "get_next_state_group", ] ) hs = Mock(spec_set=[ @@ -148,6 +153,8 @@ class StateTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() hs.get_auth.return_value = Auth(hs) + self.store.get_next_state_group.side_effect = Mock + self.state = StateHandler(hs) self.event_id = 0 @@ -187,7 +194,7 @@ class StateTestCase(unittest.TestCase): ) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids context_store = {} @@ -196,7 +203,7 @@ class StateTestCase(unittest.TestCase): store.store_state_groups(event, context) context_store[event.event_id] = context - self.assertEqual(2, len(context_store["D"].current_state)) + self.assertEqual(2, len(context_store["D"].prev_state_ids)) @defer.inlineCallbacks def test_branch_basic_conflict(self): @@ -239,7 +246,9 @@ class StateTestCase(unittest.TestCase): ) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids + self.store.get_events = store.get_events + store.register_events(graph.walk()) context_store = {} @@ -250,7 +259,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "C"}, - {e.event_id for e in context_store["D"].current_state.values()} + {e_id for e_id in context_store["D"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -303,7 +312,9 @@ class StateTestCase(unittest.TestCase): ) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids + self.store.get_events = store.get_events + store.register_events(graph.walk()) context_store = {} @@ -314,7 +325,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"START", "A", "B", "C"}, - {e.event_id for e in context_store["E"].current_state.values()} + {e for e in context_store["E"].prev_state_ids.values()} ) @defer.inlineCallbacks @@ -384,7 +395,9 @@ class StateTestCase(unittest.TestCase): graph = Graph(nodes, edges) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids + self.store.get_events = store.get_events + store.register_events(graph.walk()) context_store = {} @@ -395,7 +408,7 @@ class StateTestCase(unittest.TestCase): self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, - {e.event_id for e in context_store["D"].current_state.values()} + {e for e in context_store["D"].prev_state_ids.values()} ) def _add_depths(self, nodes, edges): @@ -424,16 +437,11 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( - set(old_state), set(context.current_state.values()) + set(e.event_id for e in old_state), set(context.current_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_annotate_with_old_state(self): @@ -449,18 +457,10 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( - set(old_state), - set(context.current_state.values()) + set(e.event_id for e in old_state), set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) - @defer.inlineCallbacks def test_trivial_annotate_message(self): event = create_event(type="test_message", name="event") @@ -473,20 +473,15 @@ class StateTestCase(unittest.TestCase): group_name = "group_name_1" - self.store.get_state_groups.return_value = { - group_name: old_state, + self.store.get_state_groups_ids.return_value = { + group_name: {(e.type, e.state_key): e.event_id for e in old_state}, } context = yield self.state.compute_event_context(event) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( set([e.event_id for e in old_state]), - set([e.event_id for e in context.current_state.values()]) + set(context.current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @@ -503,23 +498,18 @@ class StateTestCase(unittest.TestCase): group_name = "group_name_1" - self.store.get_state_groups.return_value = { - group_name: old_state, + self.store.get_state_groups_ids.return_value = { + group_name: {(e.type, e.state_key): e.event_id for e in old_state}, } context = yield self.state.compute_event_context(event) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( set([e.event_id for e in old_state]), - set([e.event_id for e in context.current_state.values()]) + set(context.prev_state_ids.values()) ) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_message_conflict(self): @@ -543,11 +533,16 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] + store = StateGroupStore() + store.register_events(old_state_1) + store.register_events(old_state_2) + self.store.get_events = store.get_events + context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 6) + self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_resolve_state_conflict(self): @@ -571,11 +566,16 @@ class StateTestCase(unittest.TestCase): create_event(type="test4", state_key=""), ] + store = StateGroupStore() + store.register_events(old_state_1) + store.register_events(old_state_2) + self.store.get_events = store.get_events + context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 6) + self.assertEqual(len(context.current_state_ids), 6) - self.assertIsNone(context.state_group) + self.assertIsNotNone(context.state_group) @defer.inlineCallbacks def test_standard_depth_conflict(self): @@ -606,9 +606,16 @@ class StateTestCase(unittest.TestCase): create_event(type="test1", state_key="1", depth=2), ] + store = StateGroupStore() + store.register_events(old_state_1) + store.register_events(old_state_2) + self.store.get_events = store.get_events + context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_2[2], context.current_state[("test1", "1")]) + self.assertEqual( + old_state_2[2].event_id, context.current_state_ids[("test1", "1")] + ) # Reverse the depth to make sure we are actually using the depths # during state resolution. @@ -625,17 +632,22 @@ class StateTestCase(unittest.TestCase): create_event(type="test1", state_key="1", depth=1), ] + store.register_events(old_state_1) + store.register_events(old_state_2) + context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_1[2], context.current_state[("test1", "1")]) + self.assertEqual( + old_state_1[2].event_id, context.current_state_ids[("test1", "1")] + ) def _get_context(self, event, old_state_1, old_state_2): group_name_1 = "group_name_1" group_name_2 = "group_name_2" - self.store.get_state_groups.return_value = { - group_name_1: old_state_1, - group_name_2: old_state_2, + self.store.get_state_groups_ids.return_value = { + group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1}, + group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2}, } return self.state.compute_event_context(event) |