summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2014-10-14 16:59:51 +0100
committerErik Johnston <erik@matrix.org>2014-10-14 16:59:51 +0100
commit5fefc12d1e2da56895d5652e3d7516ac59ab8824 (patch)
treef5e6b9a146faf0e3cd8bf7940d6e60ba89d5adc1 /synapse/state.py
parentMerge pull request #8 from matrix-org/server2server_signing (diff)
downloadsynapse-5fefc12d1e2da56895d5652e3d7516ac59ab8824.tar.xz
Begin implementing state groups.
Diffstat (limited to 'synapse/state.py')
-rw-r--r--synapse/state.py87
1 files changed, 85 insertions, 2 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 9db84c9b5c..8f09b7d50a 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -35,7 +35,7 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 
 
 class StateHandler(object):
-    """ Repsonsible for doing state conflict resolution.
+    """ Responsible for doing state conflict resolution.
     """
 
     def __init__(self, hs):
@@ -50,7 +50,7 @@ class StateHandler(object):
         to update the state and b) works out what the prev_state should be.
 
         Returns:
-            Deferred: Resolved with a boolean indicating if we succesfully
+            Deferred: Resolved with a boolean indicating if we successfully
             updated the state.
 
         Raised:
@@ -83,6 +83,8 @@ class StateHandler(object):
                 current_state.pdu_id, current_state.origin
             )
 
+        yield self.update_state_groups(event)
+
         # TODO check current_state to see if the min power level is less
         # than the power level of the user
         # power_level = self._get_power_level_for_event(event)
@@ -128,6 +130,87 @@ class StateHandler(object):
 
         defer.returnValue(is_new)
 
+    @defer.inlineCallbacks
+    def update_state_groups(self, event):
+        state_groups = yield self.store.get_state_groups(
+            event.prev_events
+        )
+
+        if len(state_groups) == 1 and not hasattr(event, "state_key"):
+            event.state_group = state_groups[0].group
+            event.current_state = state_groups[0].state
+            return
+
+        state = {}
+        state_sets = {}
+        for group in state_groups:
+            for s in group.state:
+                state.setdefault((s.type, s.state_key), []).add(s)
+
+                state_sets.setdefault(
+                    (s.type, s.state_key),
+                    set()
+                ).add(s.event_id)
+
+        unconflicted_state = {
+            k: v.pop() for k, v in state_sets.items()
+            if len(v) == 1
+        }
+
+        conflicted_state = {
+            k: state[k]
+            for k, v in state_sets.items()
+            if len(v) > 1
+        }
+
+        new_state = {}
+        new_state.update(unconflicted_state)
+        for key, events in conflicted_state.items():
+            new_state[key] = yield self.resolve(events)
+
+        if hasattr(event, "state_key"):
+            new_state[(event.type, event.state_key)] = event
+
+        event.state_group = None
+        event.current_state = new_state.values()
+
+    @defer.inlineCallbacks
+    def resolve(self, events):
+        curr_events = events
+
+        new_powers_deferreds = []
+        for e in curr_events:
+            new_powers_deferreds.append(
+                self.store.get_power_level(e.context, e.user_id)
+            )
+
+        new_powers = yield defer.gatherResults(
+            new_powers_deferreds,
+            consumeErrors=True
+        )
+
+        max_power = max([int(p) for p in new_powers])
+
+        curr_events = [
+            z[0] for z in zip(curr_events, new_powers)
+            if int(z[1]) == max_power
+        ]
+
+        if not curr_events:
+            raise RuntimeError("Max didn't get a max?")
+        elif len(curr_events) == 1:
+            defer.returnValue(curr_events[0])
+
+        # TODO: For now, just choose the one with the largest event_id.
+        defer.returnValue(
+            sorted(
+                curr_events,
+                key=lambda e: hashlib.sha1(
+                    e.event_id + e.user_id + e.room_id + e.type
+                ).hexdigest()
+            )[0]
+        )
+
     def _get_power_level_for_event(self, event):
         # return self._persistence.get_power_level_for_user(event.room_id,
             # event.sender)