diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 64c5ae9928..19b16ed404 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -283,6 +283,9 @@ class StateStore(SQLBaseStore):
def _get_state_for_group_from_cache(self, group, types=None):
"""Checks if group is in cache. See `_get_state_for_groups`
+
+ Returns 2-tuple (`state_dict`, `missing_types`). `missing_types` is the
+ list of types that aren't in the cache for that group.
"""
is_all, state_dict = self._state_group_cache.get(group)
@@ -300,29 +303,31 @@ class StateStore(SQLBaseStore):
if (typ, state_key) not in state_dict:
missing_types.add((typ, state_key))
- if is_all and types is None:
- return state_dict, missing_types
-
- if is_all or (types is not None and not missing_types):
- sentinel = object()
+ if is_all:
+ missing_types = set()
+ if types is None:
+ return state_dict, set(), True
- def include(typ, state_key):
- valid_state_keys = type_to_key.get(typ, sentinel)
- if valid_state_keys is sentinel:
- return False
- if valid_state_keys is None:
- return True
- if state_key in valid_state_keys:
- return True
- return False
+ sentinel = object()
- return {
- k: v
- for k, v in state_dict.items()
- if v and include(k[0], k[1])
- }, missing_types
+ def include(typ, state_key):
+ if types is None:
+ return True
- return {}, missing_types
+ valid_state_keys = type_to_key.get(typ, sentinel)
+ if valid_state_keys is sentinel:
+ return False
+ if valid_state_keys is None:
+ return True
+ if state_key in valid_state_keys:
+ return True
+ return False
+
+ return {
+ k: v
+ for k, v in state_dict.items()
+ if include(k[0], k[1])
+ }, missing_types, not missing_types and types is not None
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
@@ -333,25 +338,28 @@ class StateStore(SQLBaseStore):
"""
results = {}
missing_groups_and_types = []
- for group in groups:
- state_dict, missing_types = self._get_state_for_group_from_cache(
+ for group in set(groups):
+ state_dict, missing_types, got_all = self._get_state_for_group_from_cache(
group, types
)
- if types is not None and not missing_types:
- results[group] = {
- key: value
- for key, value in state_dict.items()
- if value
- }
- else:
+ results[group] = state_dict
+
+ if not got_all:
missing_groups_and_types.append((
group,
missing_types if types else None
))
if not missing_groups_and_types:
- defer.returnValue(results)
+ defer.returnValue({
+ k: {
+ key: ev
+ for key, ev in state.items()
+ if ev
+ }
+ for k, state in results.items()
+ })
# Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence
@@ -371,10 +379,15 @@ class StateStore(SQLBaseStore):
}
for group, state_ids in group_state_dict.items():
- state_dict = {
- key: None
- for key in missing_types
- }
+ if types:
+ state_dict = {
+ key: None
+ for key in types
+ }
+ state_dict.update(results[group])
+ else:
+ state_dict = results[group]
+
evs = [
state_events[e_id] for e_id in state_ids
if e_id in state_events # This can happen if event is rejected.
@@ -392,11 +405,11 @@ class StateStore(SQLBaseStore):
full=(types is None),
)
- results[group] = {
+ results[group].update({
key: value
for key, value in state_dict.items()
if value
- }
+ })
defer.returnValue(results)
|