diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 8ed8a21b0a..f06c734c4e 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -172,7 +172,7 @@ class StateStore(SQLBaseStore):
defer.returnValue(events)
def _get_state_groups_from_groups(self, groups, types):
- """Returns dictionary state_group -> state event ids
+ """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
"""
def f(txn, groups):
if types is not None:
@@ -183,7 +183,8 @@ class StateStore(SQLBaseStore):
where_clause = ""
sql = (
- "SELECT state_group, event_id FROM state_groups_state WHERE"
+ "SELECT state_group, event_id, type, state_key"
+ " FROM state_groups_state WHERE"
" state_group IN (%s) %s" % (
",".join("?" for _ in groups),
where_clause,
@@ -199,7 +200,8 @@ class StateStore(SQLBaseStore):
results = {}
for row in rows:
- results.setdefault(row["state_group"], []).append(row["event_id"])
+ key = (row["type"], row["state_key"])
+ results.setdefault(row["state_group"], {})[key] = row["event_id"]
return results
chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
@@ -296,7 +298,7 @@ class StateStore(SQLBaseStore):
where a `state_key` of `None` matches all state_keys for the
`type`.
"""
- is_all, state_dict = self._state_group_cache.get(group)
+ is_all, state_dict_ids = self._state_group_cache.get(group)
type_to_key = {}
missing_types = set()
@@ -308,7 +310,7 @@ class StateStore(SQLBaseStore):
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
- if (typ, state_key) not in state_dict:
+ if (typ, state_key) not in state_dict_ids:
missing_types.add((typ, state_key))
sentinel = object()
@@ -326,7 +328,7 @@ class StateStore(SQLBaseStore):
got_all = not (missing_types or types is None)
return {
- k: v for k, v in state_dict.items()
+ k: v for k, v in state_dict_ids.items()
if include(k[0], k[1])
}, missing_types, got_all
@@ -340,8 +342,9 @@ class StateStore(SQLBaseStore):
Args:
group: The state group to lookup
"""
- is_all, state_dict = self._state_group_cache.get(group)
- return state_dict, is_all
+ is_all, state_dict_ids = self._state_group_cache.get(group)
+
+ return state_dict_ids, is_all
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
@@ -354,84 +357,69 @@ class StateStore(SQLBaseStore):
missing_groups = []
if types is not None:
for group in set(groups):
- state_dict, missing_types, got_all = self._get_some_state_from_cache(
+ state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
group, types
)
- results[group] = state_dict
+ results[group] = state_dict_ids
if not got_all:
missing_groups.append(group)
else:
for group in set(groups):
- state_dict, got_all = self._get_all_state_from_cache(
+ state_dict_ids, got_all = self._get_all_state_from_cache(
group
)
- results[group] = state_dict
+
+ results[group] = state_dict_ids
if not got_all:
missing_groups.append(group)
- if not missing_groups:
- defer.returnValue({
- group: {
- type_tuple: event
- for type_tuple, event in state.items()
- if event
- }
- for group, state in results.items()
- })
+ if missing_groups:
+ # Okay, so we have some missing_types, lets fetch them.
+ cache_seq_num = self._state_group_cache.sequence
- # Okay, so we have some missing_types, lets fetch them.
- cache_seq_num = self._state_group_cache.sequence
+ group_to_state_dict = yield self._get_state_groups_from_groups(
+ missing_groups, types
+ )
- group_state_dict = yield self._get_state_groups_from_groups(
- missing_groups, types
- )
+ # Now we want to update the cache with all the things we fetched
+ # from the database.
+ for group, group_state_dict in group_to_state_dict.items():
+ if types:
+ # We delibrately put key -> None mappings into the cache to
+ # cache absence of the key, on the assumption that if we've
+ # explicitly asked for some types then we will probably ask
+ # for them again.
+ state_dict = {key: None for key in types}
+ state_dict.update(results[group])
+ results[group] = state_dict
+ else:
+ state_dict = results[group]
+
+ state_dict.update(group_state_dict)
+
+ self._state_group_cache.update(
+ cache_seq_num,
+ key=group,
+ value=state_dict,
+ full=(types is None),
+ )
state_events = yield self._get_events(
- [e_id for l in group_state_dict.values() for e_id in l],
+ [ev_id for sd in results.values() for ev_id in sd.values()],
get_prev_content=False
)
state_events = {e.event_id: e for e in state_events}
- # Now we want to update the cache with all the things we fetched
- # from the database.
- for group, state_ids in group_state_dict.items():
- if types:
- # We delibrately put key -> None mappings into the cache to
- # cache absence of the key, on the assumption that if we've
- # explicitly asked for some types then we will probably ask
- # for them again.
- state_dict = {key: None for key in types}
- state_dict.update(results[group])
- results[group] = state_dict
- else:
- state_dict = results[group]
-
- for event_id in state_ids:
- try:
- state_event = state_events[event_id]
- state_dict[(state_event.type, state_event.state_key)] = state_event
- except KeyError:
- # Hmm. So we do don't have that state event? Interesting.
- logger.warn(
- "Can't find state event %r for state group %r",
- event_id, group,
- )
-
- self._state_group_cache.update(
- cache_seq_num,
- key=group,
- value=state_dict,
- full=(types is None),
- )
-
# Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache.
for group, state_dict in results.items():
results[group] = {
- key: event for key, event in state_dict.items() if event
+ key: state_events[event_id]
+ for key, event_id in state_dict.items()
+ if event_id and event_id in state_events
}
defer.returnValue(results)
|