summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/state.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 430665f7ba..8a556a27f6 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -89,7 +89,7 @@ class StateHandler(object):
         ids = [e for e, _ in event.prev_events]
 
         ret = yield self.resolve_state_groups(ids)
-        state_group, new_state = ret
+        state_group, new_state, _ = ret
 
         event.old_state_events = copy.deepcopy(new_state)
 
@@ -137,7 +137,7 @@ class StateHandler(object):
 
     @defer.inlineCallbacks
     @log_function
-    def resolve_state_groups(self, event_ids):
+    def resolve_state_groups(self, event_ids, event_type=None, state_key=""):
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
@@ -156,7 +156,10 @@ class StateHandler(object):
                 (e.type, e.state_key): e
                 for e in state_list
             }
-            defer.returnValue((name, state))
+            prev_state = state.get((event_type, state_key), None)
+            if prev_state:
+                prev_state = prev_state.event_id
+            defer.returnValue((name, state, [prev_state]))
 
         state = {}
         for group, g_state in state_groups.items():
@@ -177,6 +180,13 @@ class StateHandler(object):
             if len(v.values()) > 1
         }
 
+        if event_type:
+            prev_states = conflicted_state.get(
+                (event_type, state_key), {}
+            ).keys()
+        else:
+            prev_states = []
+
         try:
             new_state = {}
             new_state.update(unconflicted_state)
@@ -186,7 +196,7 @@ class StateHandler(object):
             logger.exception("Failed to resolve state")
             raise
 
-        defer.returnValue((None, new_state))
+        defer.returnValue((None, new_state, prev_states))
 
     def _get_power_level_from_event_state(self, event, user_id):
         if hasattr(event, "old_state_events") and event.old_state_events: