summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/state.py48
1 files changed, 27 insertions, 21 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 2249b7fffb..2a01887a67 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -107,6 +107,20 @@ class StateHandler(object):
         defer.returnValue(state)
 
     @defer.inlineCallbacks
+    def get_current_state_ids(self, room_id, event_type=None, state_key="",
+                              latest_event_ids=None):
+        if not latest_event_ids:
+            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+
+        _, state = yield self.resolve_state_groups(room_id, latest_event_ids)
+
+        if event_type:
+            defer.returnValue(state.get((event_type, state_key)))
+            return
+
+        defer.returnValue(state)
+
+    @defer.inlineCallbacks
     def compute_event_context(self, event, old_state=None):
         """ Fills out the context with the `current state` of the graph. The
         `current state` here is defined to be the state of the event graph
@@ -127,27 +141,27 @@ class StateHandler(object):
             # state. Certainly store.get_current_state won't return any, and
             # persisting the event won't store the state group.
             if old_state:
-                context.current_state = {
-                    (s.type, s.state_key): s for s in old_state
+                context.current_state_ids = {
+                    (s.type, s.state_key): s.event_id for s in old_state
                 }
             else:
-                context.current_state = {}
+                context.current_state_ids = {}
             context.prev_state_events = []
             context.state_group = None
             defer.returnValue(context)
 
         if old_state:
-            context.current_state = {
-                (s.type, s.state_key): s for s in old_state
+            context.current_state_ids = {
+                (s.type, s.state_key): s.event_id for s in old_state
             }
             context.state_group = None
 
             if event.is_state():
                 key = (event.type, event.state_key)
-                if key in context.current_state:
-                    replaces = context.current_state[key]
-                    if replaces.event_id != event.event_id:  # Paranoia check
-                        event.unsigned["replaces_state"] = replaces.event_id
+                if key in context.current_state_ids:
+                    replaces = context.current_state_ids[key]
+                    if replaces != event.event_id:  # Paranoia check
+                        event.unsigned["replaces_state"] = replaces
 
             context.prev_state_events = []
             defer.returnValue(context)
@@ -165,22 +179,14 @@ class StateHandler(object):
 
         group, curr_state = ret
 
-        state_map = yield self.store.get_events(
-            curr_state.values(),
-            get_prev_content=False
-        )
-        curr_state = {
-            key: state_map[e_id] for key, e_id in curr_state.items() if e_id in state_map
-        }
-
-        context.current_state = curr_state
+        context.current_state_ids = curr_state
         context.state_group = group if not event.is_state() else None
 
         if event.is_state():
             key = (event.type, event.state_key)
-            if key in context.current_state:
-                replaces = context.current_state[key]
-                event.unsigned["replaces_state"] = replaces.event_id
+            if key in context.current_state_ids:
+                replaces = context.current_state_ids[key]
+                event.unsigned["replaces_state"] = replaces
 
         context.prev_state_events = []
         defer.returnValue(context)