summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state.py')
-rw-r--r--synapse/state.py26
1 files changed, 20 insertions, 6 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 1bca0f8f78..d0f76dc4f5 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -75,7 +75,8 @@ class StateHandler(object):
         self._state_cache.start()
 
     @defer.inlineCallbacks
-    def get_current_state(self, room_id, event_type=None, state_key=""):
+    def get_current_state(self, room_id, event_type=None, state_key="",
+                          latest_event_ids=None):
         """ Retrieves the current state for the room. This is done by
         calling `get_latest_events_in_room` to get the leading edges of the
         event graph and then resolving any of the state conflicts.
@@ -89,9 +90,10 @@ class StateHandler(object):
         Returns:
             map from (type, state_key) to event
         """
-        event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+        if not latest_event_ids:
+            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-        res = yield self.resolve_state_groups(room_id, event_ids)
+        res = yield self.resolve_state_groups(room_id, latest_event_ids)
         state = res[1]
 
         if event_type:
@@ -212,7 +214,7 @@ class StateHandler(object):
 
         if self._state_cache is not None:
             cache = self._state_cache.get(group_names, None)
-            if cache and cache.state_group:
+            if cache:
                 cache.ts = self.clock.time_msec()
 
                 event_dict = yield self.store.get_events(cache.state.values())
@@ -228,22 +230,34 @@ class StateHandler(object):
                     (cache.state_group, state, prev_states)
                 )
 
+        logger.info("Resolving state for %s with %d groups", room_id, len(state_groups))
+
         new_state, prev_states = self._resolve_events(
             state_groups.values(), event_type, state_key
         )
 
+        state_group = None
+        new_state_event_ids = frozenset(e.event_id for e in new_state.values())
+        for sg, events in state_groups.items():
+            if new_state_event_ids == frozenset(e.event_id for e in events):
+                state_group = sg
+                break
+
         if self._state_cache is not None:
             cache = _StateCacheEntry(
                 state={key: event.event_id for key, event in new_state.items()},
-                state_group=None,
+                state_group=state_group,
                 ts=self.clock.time_msec()
             )
 
             self._state_cache[group_names] = cache
 
-        defer.returnValue((None, new_state, prev_states))
+        defer.returnValue((state_group, new_state, prev_states))
 
     def resolve_events(self, state_sets, event):
+        logger.info(
+            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
+        )
         if event.is_state():
             return self._resolve_events(
                 state_sets, event.type, event.state_key