summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state.py')
-rw-r--r--synapse/state.py43
1 files changed, 29 insertions, 14 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 98aaa2be53..fe5f3dc84b 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -259,13 +259,37 @@ class StateHandler(object):
 
             defer.returnValue((name, state, prev_states))
 
+        new_state, prev_states = self._resolve_events(
+            state_groups.values(), event_type, state_key
+        )
+
+        if self._state_cache is not None:
+            cache = _StateCacheEntry(
+                state=new_state,
+                state_group=None,
+                ts=self.clock.time_msec()
+            )
+
+            self._state_cache[frozenset(event_ids)] = cache
+
+        defer.returnValue((None, new_state, prev_states))
+
+    def resolve_events(self, state_sets, event):
+        if event.is_state():
+            return self._resolve_events(
+                state_sets, event.type, event.state_key
+            )
+        else:
+            return self._resolve_events(state_sets)
+
+    def _resolve_events(self, state_sets, event_type=None, state_key=""):
         state = {}
-        for group, g_state in state_groups.items():
-            for s in g_state:
+        for st in state_sets:
+            for e in st:
                 state.setdefault(
-                    (s.type, s.state_key),
+                    (e.type, e.state_key),
                     {}
-                )[s.event_id] = s
+                )[e.event_id] = e
 
         unconflicted_state = {
             k: v.values()[0] for k, v in state.items()
@@ -302,16 +326,7 @@ class StateHandler(object):
         new_state = unconflicted_state
         new_state.update(resolved_state)
 
-        if self._state_cache is not None:
-            cache = _StateCacheEntry(
-                state=new_state,
-                state_group=None,
-                ts=self.clock.time_msec()
-            )
-
-            self._state_cache[frozenset(event_ids)] = cache
-
-        defer.returnValue((None, new_state, prev_states))
+        return new_state, prev_states
 
     @log_function
     def _resolve_state_events(self, conflicted_state, auth_events):