diff options
-rw-r--r-- | synapse/state.py | 66 | ||||
-rw-r--r-- | tests/test_state.py | 108 |
2 files changed, 39 insertions, 135 deletions
diff --git a/synapse/state.py b/synapse/state.py index 38adde4dc9..61b14b939f 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -22,7 +22,6 @@ from synapse.api.constants import EventTypes from collections import namedtuple -import copy import logging import hashlib @@ -44,71 +43,6 @@ class StateHandler(object): self.store = hs.get_datastore() @defer.inlineCallbacks - @log_function - def annotate_event_with_state(self, event, old_state=None): - """ Annotates the event with the current state events as of that event. - - This method adds three new attributes to the event: - * `state_events`: The state up to and including the event. Encoded - as a dict mapping tuple (type, state_key) -> event. - * `old_state_events`: The state up to, but excluding, the event. - Encoded similarly as `state_events`. - * `state_group`: If there is an existing state group that can be - used, then return that. Otherwise return `None`. See state - storage for more information. - - If the argument `old_state` is given (in the form of a list of - events), then they are used as a the values for `old_state_events` and - the value for `state_events` is generated from it. `state_group` is - set to None. - - This needs to be called before persisting the event. - """ - yield run_on_reactor() - - if old_state: - event.state_group = None - event.old_state_events = { - (s.type, s.state_key): s for s in old_state - } - event.state_events = event.old_state_events - - if hasattr(event, "state_key"): - event.state_events[(event.type, event.state_key)] = event - - defer.returnValue(False) - return - - if hasattr(event, "outlier") and event.outlier: - event.state_group = None - event.old_state_events = None - event.state_events = None - defer.returnValue(False) - return - - ids = [e for e, _ in event.prev_events] - - ret = yield self.resolve_state_groups(ids) - state_group, new_state, _ = ret - - event.old_state_events = copy.deepcopy(new_state) - - if hasattr(event, "state_key"): - key = (event.type, event.state_key) - if key in new_state: - event.replaces_state = new_state[key].event_id - new_state[key] = event - elif state_group: - event.state_group = state_group - event.state_events = new_state - defer.returnValue(False) - - event.state_group = None - event.state_events = new_state - - defer.returnValue(hasattr(event, "state_key")) - - @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): """ Returns the current state for the room as a list. This is done by calling `get_latest_events_in_room` to get the leading edges of the diff --git a/tests/test_state.py b/tests/test_state.py index 7979b54a35..197e35f140 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -26,6 +26,7 @@ class StateTestCase(unittest.TestCase): self.store = Mock( spec_set=[ "get_state_groups", + "add_event_hashes", ] ) hs = Mock(spec=["get_datastore"]) @@ -37,6 +38,7 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_annotate_with_old_message(self): event = self.create_event(type="test_message", name="event") + context = Mock() old_state = [ self.create_event(type="test1", state_key="1"), @@ -44,21 +46,25 @@ class StateTestCase(unittest.TestCase): self.create_event(type="test2", state_key=""), ] - yield self.state.annotate_event_with_state(event, old_state=old_state) + yield self.state.annotate_context_with_state( + event, context, old_state=old_state + ) - for k, v in event.old_state_events.items(): + for k, v in context.current_state.items(): type, state_key = k self.assertEqual(type, v.type) self.assertEqual(state_key, v.state_key) - self.assertEqual(set(old_state), set(event.old_state_events.values())) - self.assertDictEqual(event.old_state_events, event.state_events) + self.assertEqual( + set(old_state), set(context.current_state.values()) + ) - self.assertIsNone(event.state_group) + self.assertIsNone(context.state_group) @defer.inlineCallbacks def test_annotate_with_old_state(self): event = self.create_event(type="state", state_key="", name="event") + context = Mock() old_state = [ self.create_event(type="test1", state_key="1"), @@ -66,26 +72,27 @@ class StateTestCase(unittest.TestCase): self.create_event(type="test2", state_key=""), ] - yield self.state.annotate_event_with_state(event, old_state=old_state) + yield self.state.annotate_context_with_state( + event, context, old_state=old_state + ) - for k, v in event.old_state_events.items(): + for k, v in context.current_state.items(): type, state_key = k self.assertEqual(type, v.type) self.assertEqual(state_key, v.state_key) self.assertEqual( - set(old_state + [event]), - set(event.old_state_events.values()) + set(old_state), + set(context.current_state.values()) ) - self.assertDictEqual(event.old_state_events, event.state_events) - - self.assertIsNone(event.state_group) + self.assertIsNone(context.state_group) @defer.inlineCallbacks def test_trivial_annotate_message(self): event = self.create_event(type="test_message", name="event") event.prev_events = [] + context = Mock() old_state = [ self.create_event(type="test1", state_key="1"), @@ -99,35 +106,25 @@ class StateTestCase(unittest.TestCase): group_name: old_state, } - yield self.state.annotate_event_with_state(event) + yield self.state.annotate_context_with_state(event, context) - for k, v in event.old_state_events.items(): + for k, v in context.current_state.items(): type, state_key = k self.assertEqual(type, v.type) self.assertEqual(state_key, v.state_key) self.assertEqual( set([e.event_id for e in old_state]), - set([e.event_id for e in event.old_state_events.values()]) + set([e.event_id for e in context.current_state.values()]) ) - self.assertDictEqual( - { - k: v.event_id - for k, v in event.old_state_events.items() - }, - { - k: v.event_id - for k, v in event.state_events.items() - } - ) - - self.assertEqual(group_name, event.state_group) + self.assertEqual(group_name, context.state_group) @defer.inlineCallbacks def test_trivial_annotate_state(self): event = self.create_event(type="state", state_key="", name="event") event.prev_events = [] + context = Mock() old_state = [ self.create_event(type="test1", state_key="1"), @@ -141,43 +138,25 @@ class StateTestCase(unittest.TestCase): group_name: old_state, } - yield self.state.annotate_event_with_state(event) + yield self.state.annotate_context_with_state(event, context) - for k, v in event.old_state_events.items(): + for k, v in context.current_state.items(): type, state_key = k self.assertEqual(type, v.type) self.assertEqual(state_key, v.state_key) self.assertEqual( set([e.event_id for e in old_state]), - set([e.event_id for e in event.old_state_events.values()]) - ) - - self.assertEqual( - set([e.event_id for e in old_state] + [event.event_id]), - set([e.event_id for e in event.state_events.values()]) - ) - - new_state = { - k: v.event_id - for k, v in event.state_events.items() - } - old_state = { - k: v.event_id - for k, v in event.old_state_events.items() - } - old_state[(event.type, event.state_key)] = event.event_id - self.assertDictEqual( - old_state, - new_state + set([e.event_id for e in context.current_state.values()]) ) - self.assertIsNone(event.state_group) + self.assertIsNone(context.state_group) @defer.inlineCallbacks def test_resolve_message_conflict(self): event = self.create_event(type="test_message", name="event") event.prev_events = [] + context = Mock() old_state_1 = [ self.create_event(type="test1", state_key="1"), @@ -199,21 +178,17 @@ class StateTestCase(unittest.TestCase): group_name_2: old_state_2, } - yield self.state.annotate_event_with_state(event) + yield self.state.annotate_context_with_state(event, context) - self.assertEqual(len(event.old_state_events), 5) - - self.assertEqual( - set([e.event_id for e in event.state_events.values()]), - set([e.event_id for e in event.old_state_events.values()]) - ) + self.assertEqual(len(context.current_state), 5) - self.assertIsNone(event.state_group) + self.assertIsNone(context.state_group) @defer.inlineCallbacks def test_resolve_state_conflict(self): event = self.create_event(type="test4", state_key="", name="event") event.prev_events = [] + context = Mock() old_state_1 = [ self.create_event(type="test1", state_key="1"), @@ -235,19 +210,11 @@ class StateTestCase(unittest.TestCase): group_name_2: old_state_2, } - yield self.state.annotate_event_with_state(event) + yield self.state.annotate_context_with_state(event, context) - self.assertEqual(len(event.old_state_events), 5) + self.assertEqual(len(context.current_state), 5) - expected_new = event.old_state_events - expected_new[(event.type, event.state_key)] = event - - self.assertEqual( - set([e.event_id for e in expected_new.values()]), - set([e.event_id for e in event.state_events.values()]), - ) - - self.assertIsNone(event.state_group) + self.assertIsNone(context.state_group) def create_event(self, name=None, type=None, state_key=None): self.event_id += 1 @@ -266,6 +233,9 @@ class StateTestCase(unittest.TestCase): event.state_key = state_key event.event_id = event_id + event.is_state = lambda: (state_key is not None) + event.unsigned = {} + event.user_id = "@user_id:example.com" event.room_id = "!room_id:example.com" |