summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/state.py86
1 files changed, 61 insertions, 25 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 8c4eeb8924..24685c6fb4 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 
 from synapse.federation.pdu_codec import encode_event_id, decode_event_id
 from synapse.util.logutils import log_function
+from synapse.federation.pdu_codec import encode_event_id
 
 from collections import namedtuple
 
@@ -130,54 +131,89 @@ class StateHandler(object):
         defer.returnValue(is_new)
 
     @defer.inlineCallbacks
+    @log_function
     def annotate_state_groups(self, event, state=None):
         if state:
             event.state_group = None
             event.old_state_events = None
-            event.state_events = state
+            event.state_events = {(s.type, s.state_key): s for s in state}
+            defer.returnValue(False)
+            return
+
+        if hasattr(event, "outlier") and event.outlier:
+            event.state_group = None
+            event.old_state_events = None
+            event.state_events = None
+            defer.returnValue(False)
             return
 
+        new_state = yield self.resolve_state_groups(event.prev_events)
+
+        event.old_state_events = new_state
+
+        if hasattr(event, "state_key"):
+            new_state[(event.type, event.state_key)] = event
+
+        event.state_group = None
+        event.state_events = new_state
+
+        defer.returnValue(hasattr(event, "state_key"))
+
+    @defer.inlineCallbacks
+    def get_current_state(self, room_id, event_type=None, state_key=""):
+        # FIXME: HACK!
+        pdus = yield self.store.get_latest_pdus_in_context(room_id)
+
+        event_ids = [encode_event_id(p.pdu_id, p.origin) for p in pdus]
+
+        res = self.resolve_state_groups(event_ids)
+
+        if event_type:
+            defer.returnValue(res.get((event_type, state_key)))
+            return
+
+        defer.returnValue(res.values())
+
+    @defer.inlineCallbacks
+    @log_function
+    def resolve_state_groups(self, event_ids):
         state_groups = yield self.store.get_state_groups(
-            event.prev_events
+            event_ids
         )
 
         state = {}
-        state_sets = {}
         for group in state_groups:
             for s in group.state:
-                state.setdefault((s.type, s.state_key), []).append(s)
-
-                state_sets.setdefault(
+                state.setdefault(
                     (s.type, s.state_key),
-                    set()
-                ).add(s.event_id)
+                    {}
+                )[s.event_id] = s
 
         unconflicted_state = {
-            k: state[k].pop() for k, v in state_sets.items()
-            if len(v) == 1
+            k: v.values()[0] for k, v in state.items()
+            if len(v.values()) == 1
         }
 
         conflicted_state = {
-            k: state[k]
-            for k, v in state_sets.items()
-            if len(v) > 1
+            k: v.values()
+            for k, v in state.items()
+            if len(v.values()) > 1
         }
 
-        new_state = {}
-        new_state.update(unconflicted_state)
-        for key, events in conflicted_state.items():
-            new_state[key] = yield self.resolve(events)
+        try:
+            new_state = {}
+            new_state.update(unconflicted_state)
+            for key, events in conflicted_state.items():
+                new_state[key] = yield self._resolve_state_events(events)
+        except:
+            logger.exception("Failed to resolve state")
+            raise
 
-        event.old_state_events = new_state
-
-        if hasattr(event, "state_key"):
-            new_state[(event.type, event.state_key)] = event
-
-        event.state_group = None
-        event.state_events = new_state
+        defer.returnValue(new_state)
 
     @defer.inlineCallbacks
-    def resolve(self, events):
+    @log_function
+    def _resolve_state_events(self, events):
         curr_events = events
 
         new_powers_deferreds = []