diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0e8fa93e1f..fa40af6933 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -44,11 +44,7 @@ class StateStore(SQLBaseStore):
"""
@defer.inlineCallbacks
- def get_state_groups(self, room_id, event_ids):
- """ Get the state groups for the given list of event_ids
-
- The return value is a dict mapping group names to lists of events.
- """
+ def get_state_groups_ids(self, room_id, event_ids):
if not event_ids:
defer.returnValue({})
@@ -59,9 +55,32 @@ class StateStore(SQLBaseStore):
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups)
+ defer.returnValue(group_to_state)
+
+ @defer.inlineCallbacks
+ def get_state_groups(self, room_id, event_ids):
+ """ Get the state groups for the given list of event_ids
+
+ The return value is a dict mapping group names to lists of events.
+ """
+ if not event_ids:
+ defer.returnValue({})
+
+ group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+
+ state_event_map = yield self.get_events(
+ [
+ ev_id for group_ids in group_to_ids.values()
+ for ev_id in group_ids.values()
+ ],
+ get_prev_content=False
+ )
+
defer.returnValue({
- group: state_map.values()
- for group, state_map in group_to_state.items()
+ group: [
+ state_event_map[v] for v in event_id_map.values() if v in state_event_map
+ ]
+ for group, event_id_map in group_to_ids.items()
})
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
@@ -248,8 +267,17 @@ class StateStore(SQLBaseStore):
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups, types)
+ state_event_map = yield self.get_events(
+ [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+ get_prev_content=False
+ )
+
event_to_state = {
- event_id: group_to_state[group]
+ event_id: {
+ k: state_event_map[v]
+ for k, v in group_to_state[group].items()
+ if v in state_event_map
+ }
for event_id, group in event_to_groups.items()
}
@@ -428,20 +456,13 @@ class StateStore(SQLBaseStore):
full=(types is None),
)
- state_events = yield self._get_events(
- [ev_id for sd in results.values() for ev_id in sd.values()],
- get_prev_content=False
- )
-
- state_events = {e.event_id: e for e in state_events}
-
# Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache.
for group, state_dict in results.items():
results[group] = {
- key: state_events[event_id]
+ key: event_id
for key, event_id in state_dict.items()
- if event_id and event_id in state_events
+ if event_id
}
defer.returnValue(results)
|