summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state.py')
-rw-r--r--synapse/state.py31
1 files changed, 23 insertions, 8 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 147416fd81..a0f807e3b9 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -128,7 +128,7 @@ class StateHandler(object):
     def get_current_user_in_room(self, room_id):
         latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
         group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
-        joined_users = yield self.store.get_joined_users_from_context(
+        joined_users = yield self.store.get_joined_users_from_state(
             room_id, group, state_ids
         )
         defer.returnValue(joined_users)
@@ -154,27 +154,38 @@ class StateHandler(object):
             # state. Certainly store.get_current_state won't return any, and
             # persisting the event won't store the state group.
             if old_state:
-                context.current_state_ids = {
+                context.prev_state_ids = {
                     (s.type, s.state_key): s.event_id for s in old_state
                 }
+                if event.is_state():
+                    context.current_state_events = dict(context.prev_state_ids)
+                    key = (event.type, event.state_key)
+                    context.current_state_events[key] = event.event_id
+                else:
+                    context.current_state_events = context.prev_state_ids
             else:
                 context.current_state_ids = {}
+                context.prev_state_ids = {}
             context.prev_state_events = []
             context.state_group = self.store.get_next_state_group()
             defer.returnValue(context)
 
         if old_state:
-            context.current_state_ids = {
+            context.prev_state_ids = {
                 (s.type, s.state_key): s.event_id for s in old_state
             }
             context.state_group = self.store.get_next_state_group()
 
             if event.is_state():
                 key = (event.type, event.state_key)
-                if key in context.current_state_ids:
-                    replaces = context.current_state_ids[key]
+                if key in context.prev_state_ids:
+                    replaces = context.prev_state_ids[key]
                     if replaces != event.event_id:  # Paranoia check
                         event.unsigned["replaces_state"] = replaces
+                context.current_state_ids = dict(context.prev_state_ids)
+                context.current_state_ids[key] = event.event_id
+            else:
+                context.current_state_ids = context.prev_state_ids
 
             context.prev_state_events = []
             defer.returnValue(context)
@@ -192,7 +203,7 @@ class StateHandler(object):
 
         group, curr_state = ret
 
-        context.current_state_ids = curr_state
+        context.prev_state_ids = curr_state
         if event.is_state() or group is None:
             context.state_group = self.store.get_next_state_group()
         else:
@@ -200,9 +211,13 @@ class StateHandler(object):
 
         if event.is_state():
             key = (event.type, event.state_key)
-            if key in context.current_state_ids:
-                replaces = context.current_state_ids[key]
+            if key in context.prev_state_ids:
+                replaces = context.prev_state_ids[key]
                 event.unsigned["replaces_state"] = replaces
+            context.current_state_ids = dict(context.prev_state_ids)
+            context.current_state_ids[key] = event.event_id
+        else:
+            context.current_state_ids = context.prev_state_ids
 
         context.prev_state_events = []
         defer.returnValue(context)