diff --git a/synapse/state.py b/synapse/state.py
index 9db84c9b5c..8f09b7d50a 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -35,7 +35,7 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
class StateHandler(object):
- """ Repsonsible for doing state conflict resolution.
+ """ Responsible for doing state conflict resolution.
"""
def __init__(self, hs):
@@ -50,7 +50,7 @@ class StateHandler(object):
to update the state and b) works out what the prev_state should be.
Returns:
- Deferred: Resolved with a boolean indicating if we succesfully
+ Deferred: Resolved with a boolean indicating if we successfully
updated the state.
Raised:
@@ -83,6 +83,8 @@ class StateHandler(object):
current_state.pdu_id, current_state.origin
)
+ yield self.update_state_groups(event)
+
# TODO check current_state to see if the min power level is less
# than the power level of the user
# power_level = self._get_power_level_for_event(event)
@@ -128,6 +130,87 @@ class StateHandler(object):
defer.returnValue(is_new)
+ @defer.inlineCallbacks
+ def update_state_groups(self, event):
+ state_groups = yield self.store.get_state_groups(
+ event.prev_events
+ )
+
+ if len(state_groups) == 1 and not hasattr(event, "state_key"):
+ event.state_group = state_groups[0].group
+ event.current_state = state_groups[0].state
+ return
+
+ state = {}
+ state_sets = {}
+ for group in state_groups:
+ for s in group.state:
+ state.setdefault((s.type, s.state_key), []).add(s)
+
+ state_sets.setdefault(
+ (s.type, s.state_key),
+ set()
+ ).add(s.event_id)
+
+ unconflicted_state = {
+ k: v.pop() for k, v in state_sets.items()
+ if len(v) == 1
+ }
+
+ conflicted_state = {
+ k: state[k]
+ for k, v in state_sets.items()
+ if len(v) > 1
+ }
+
+ new_state = {}
+ new_state.update(unconflicted_state)
+ for key, events in conflicted_state.items():
+ new_state[key] = yield self.resolve(events)
+
+ if hasattr(event, "state_key"):
+ new_state[(event.type, event.state_key)] = event
+
+ event.state_group = None
+ event.current_state = new_state.values()
+
+ @defer.inlineCallbacks
+ def resolve(self, events):
+ curr_events = events
+
+ new_powers_deferreds = []
+ for e in curr_events:
+ new_powers_deferreds.append(
+ self.store.get_power_level(e.context, e.user_id)
+ )
+
+ new_powers = yield defer.gatherResults(
+ new_powers_deferreds,
+ consumeErrors=True
+ )
+
+ max_power = max([int(p) for p in new_powers])
+
+ curr_events = [
+ z[0] for z in zip(curr_events, new_powers)
+ if int(z[1]) == max_power
+ ]
+
+ if not curr_events:
+ raise RuntimeError("Max didn't get a max?")
+ elif len(curr_events) == 1:
+ defer.returnValue(curr_events[0])
+
+ # TODO: For now, just choose the one with the largest event_id.
+ defer.returnValue(
+ sorted(
+ curr_events,
+ key=lambda e: hashlib.sha1(
+ e.event_id + e.user_id + e.room_id + e.type
+ ).hexdigest()
+ )[0]
+ )
+
def _get_power_level_for_event(self, event):
# return self._persistence.get_power_level_for_user(event.room_id,
# event.sender)
|