summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2014-10-15 16:06:59 +0100
committerErik Johnston <erik@matrix.org>2014-10-15 16:06:59 +0100
commite7bc1291a079224315cea5c756061ad711241be1 (patch)
treeb9f303a57a6700df30e1b985a991ca1b7d09f63c /synapse/state.py
parentAdd missing package storate.state (diff)
downloadsynapse-e7bc1291a079224315cea5c756061ad711241be1.tar.xz
Begin making auth use event.old_state_events
Diffstat (limited to 'synapse/state.py')
-rw-r--r--synapse/state.py18
1 files changed, 7 insertions, 11 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 8f09b7d50a..9be6b716e2 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -71,6 +71,7 @@ class StateHandler(object):
         # (w.r.t. to power levels)
 
         snapshot.fill_out_prev_events(event)
+        yield self.annotate_state_groups(event)
 
         event.prev_events = [
             e for e in event.prev_events if e != event.event_id
@@ -83,8 +84,6 @@ class StateHandler(object):
                 current_state.pdu_id, current_state.origin
             )
 
-        yield self.update_state_groups(event)
-
         # TODO check current_state to see if the min power level is less
         # than the power level of the user
         # power_level = self._get_power_level_for_event(event)
@@ -131,21 +130,16 @@ class StateHandler(object):
         defer.returnValue(is_new)
 
     @defer.inlineCallbacks
-    def update_state_groups(self, event):
+    def annotate_state_groups(self, event):
         state_groups = yield self.store.get_state_groups(
             event.prev_events
         )
 
-        if len(state_groups) == 1 and not hasattr(event, "state_key"):
-            event.state_group = state_groups[0].group
-            event.current_state = state_groups[0].state
-            return
-
         state = {}
         state_sets = {}
         for group in state_groups:
             for s in group.state:
-                state.setdefault((s.type, s.state_key), []).add(s)
+                state.setdefault((s.type, s.state_key), []).append(s)
 
                 state_sets.setdefault(
                     (s.type, s.state_key),
@@ -153,7 +147,7 @@ class StateHandler(object):
                 ).add(s.event_id)
 
         unconflicted_state = {
-            k: v.pop() for k, v in state_sets.items()
+            k: state[k].pop() for k, v in state_sets.items()
             if len(v) == 1
         }
 
@@ -168,11 +162,13 @@ class StateHandler(object):
         for key, events in conflicted_state.items():
             new_state[key] = yield self.resolve(events)
 
+        event.old_state_events = new_state
+
         if hasattr(event, "state_key"):
             new_state[(event.type, event.state_key)] = event
 
         event.state_group = None
-        event.current_state = new_state.values()
+        event.state_events = new_state
 
     @defer.inlineCallbacks
     def resolve(self, events):