summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state.py')
-rw-r--r--synapse/state.py129
1 files changed, 127 insertions, 2 deletions
diff --git a/synapse/state.py b/synapse/state.py
index bc6b928ec7..ec98667c5d 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -18,9 +18,11 @@ from twisted.internet import defer
 
 from synapse.federation.pdu_codec import encode_event_id, decode_event_id
 from synapse.util.logutils import log_function
+from synapse.federation.pdu_codec import encode_event_id
 
 from collections import namedtuple
 
+import copy
 import logging
 import hashlib
 
@@ -35,7 +37,7 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 
 
 class StateHandler(object):
-    """ Repsonsible for doing state conflict resolution.
+    """ Responsible for doing state conflict resolution.
     """
 
     def __init__(self, hs):
@@ -50,7 +52,7 @@ class StateHandler(object):
         to update the state and b) works out what the prev_state should be.
 
         Returns:
-            Deferred: Resolved with a boolean indicating if we succesfully
+            Deferred: Resolved with a boolean indicating if we successfully
             updated the state.
 
         Raised:
@@ -71,6 +73,7 @@ class StateHandler(object):
         # (w.r.t. to power levels)
 
         snapshot.fill_out_prev_events(event)
+        yield self.annotate_state_groups(event)
 
         current_state = snapshot.prev_state_pdu
 
@@ -124,6 +127,128 @@ class StateHandler(object):
 
         defer.returnValue(is_new)
 
+    @defer.inlineCallbacks
+    @log_function
+    def annotate_state_groups(self, event, state=None):
+        if state:
+            event.state_group = None
+            event.old_state_events = None
+            event.state_events = {(s.type, s.state_key): s for s in state}
+            defer.returnValue(False)
+            return
+
+        if hasattr(event, "outlier") and event.outlier:
+            event.state_group = None
+            event.old_state_events = None
+            event.state_events = {}
+            defer.returnValue(False)
+            return
+
+        new_state = yield self.resolve_state_groups(event.prev_events)
+
+        event.old_state_events = copy.deepcopy(new_state)
+
+        if hasattr(event, "state_key"):
+            new_state[(event.type, event.state_key)] = event
+
+        event.state_group = None
+        event.state_events = new_state
+
+        defer.returnValue(hasattr(event, "state_key"))
+
+    @defer.inlineCallbacks
+    def get_current_state(self, room_id, event_type=None, state_key=""):
+        # FIXME: HACK!
+        pdus = yield self.store.get_latest_pdus_in_context(room_id)
+
+        event_ids = [
+            encode_event_id(pdu_id, origin)
+            for pdu_id, origin, _ in pdus
+        ]
+
+        res = yield self.resolve_state_groups(event_ids)
+
+        if event_type:
+            defer.returnValue(res.get((event_type, state_key)))
+            return
+
+        defer.returnValue(res.values())
+
+    @defer.inlineCallbacks
+    @log_function
+    def resolve_state_groups(self, event_ids):
+        state_groups = yield self.store.get_state_groups(
+            event_ids
+        )
+
+        state = {}
+        for group in state_groups:
+            for s in group.state:
+                state.setdefault(
+                    (s.type, s.state_key),
+                    {}
+                )[s.event_id] = s
+
+        unconflicted_state = {
+            k: v.values()[0] for k, v in state.items()
+            if len(v.values()) == 1
+        }
+
+        conflicted_state = {
+            k: v.values()
+            for k, v in state.items()
+            if len(v.values()) > 1
+        }
+
+        try:
+            new_state = {}
+            new_state.update(unconflicted_state)
+            for key, events in conflicted_state.items():
+                new_state[key] = yield self._resolve_state_events(events)
+        except:
+            logger.exception("Failed to resolve state")
+            raise
+
+        defer.returnValue(new_state)
+
+    @defer.inlineCallbacks
+    @log_function
+    def _resolve_state_events(self, events):
+        curr_events = events
+
+        new_powers_deferreds = []
+        for e in curr_events:
+            new_powers_deferreds.append(
+                self.store.get_power_level(e.room_id, e.user_id)
+            )
+
+        new_powers = yield defer.gatherResults(
+            new_powers_deferreds,
+            consumeErrors=True
+        )
+
+        max_power = max([int(p) for p in new_powers])
+
+        curr_events = [
+            z[0] for z in zip(curr_events, new_powers)
+            if int(z[1]) == max_power
+        ]
+
+        if not curr_events:
+            raise RuntimeError("Max didn't get a max?")
+        elif len(curr_events) == 1:
+            defer.returnValue(curr_events[0])
+
+        # TODO: For now, just choose the one with the largest event_id.
+        defer.returnValue(
+            sorted(
+                curr_events,
+                key=lambda e: hashlib.sha1(
+                    e.event_id + e.user_id + e.room_id + e.type
+                ).hexdigest()
+            )[0]
+        )
+
     def _get_power_level_for_event(self, event):
         # return self._persistence.get_power_level_for_user(event.room_id,
             # event.sender)