diff --git a/synapse/state.py b/synapse/state.py
index 8f09b7d50a..9be6b716e2 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -71,6 +71,7 @@ class StateHandler(object):
# (w.r.t. to power levels)
snapshot.fill_out_prev_events(event)
+ yield self.annotate_state_groups(event)
event.prev_events = [
e for e in event.prev_events if e != event.event_id
@@ -83,8 +84,6 @@ class StateHandler(object):
current_state.pdu_id, current_state.origin
)
- yield self.update_state_groups(event)
-
# TODO check current_state to see if the min power level is less
# than the power level of the user
# power_level = self._get_power_level_for_event(event)
@@ -131,21 +130,16 @@ class StateHandler(object):
defer.returnValue(is_new)
@defer.inlineCallbacks
- def update_state_groups(self, event):
+ def annotate_state_groups(self, event):
state_groups = yield self.store.get_state_groups(
event.prev_events
)
- if len(state_groups) == 1 and not hasattr(event, "state_key"):
- event.state_group = state_groups[0].group
- event.current_state = state_groups[0].state
- return
-
state = {}
state_sets = {}
for group in state_groups:
for s in group.state:
- state.setdefault((s.type, s.state_key), []).add(s)
+ state.setdefault((s.type, s.state_key), []).append(s)
state_sets.setdefault(
(s.type, s.state_key),
@@ -153,7 +147,7 @@ class StateHandler(object):
).add(s.event_id)
unconflicted_state = {
- k: v.pop() for k, v in state_sets.items()
+ k: state[k].pop() for k, v in state_sets.items()
if len(v) == 1
}
@@ -168,11 +162,13 @@ class StateHandler(object):
for key, events in conflicted_state.items():
new_state[key] = yield self.resolve(events)
+ event.old_state_events = new_state
+
if hasattr(event, "state_key"):
new_state[(event.type, event.state_key)] = event
event.state_group = None
- event.current_state = new_state.values()
+ event.state_events = new_state
@defer.inlineCallbacks
def resolve(self, events):
|