diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 79c3b82d9f..1293842361 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -287,42 +287,39 @@ class StateStore(SQLBaseStore):
f,
)
- def _get_state_for_group_from_cache(self, group, types=None):
+ def _get_some_state_from_cache(self, group, types):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
`missing_types` is the list of types that aren't in the cache for that
- group, or None if `types` is None. `got_all` is a bool indicating if
- we successfully retrieved all requests state from the cache, if False
- we need to query the DB for the missing state.
+ group. `got_all` is a bool indicating if we successfully retrieved all
+ requests state from the cache, if False we need to query the DB for the
+ missing state.
+
+ Args:
+ group: The state group to lookup
+ types (list): List of 2-tuples of the form (`type`, `state_key`),
+ where a `state_key` of `None` matches all state_keys for the
+ `type`.
"""
is_all, state_dict = self._state_group_cache.get(group)
type_to_key = {}
missing_types = set()
- if types is not None:
- for typ, state_key in types:
- if state_key is None:
- type_to_key[typ] = None
- missing_types.add((typ, state_key))
- else:
- if type_to_key.get(typ, object()) is not None:
- type_to_key.setdefault(typ, set()).add(state_key)
-
- if (typ, state_key) not in state_dict:
- missing_types.add((typ, state_key))
+ for typ, state_key in types:
+ if state_key is None:
+ type_to_key[typ] = None
+ missing_types.add((typ, state_key))
+ else:
+ if type_to_key.get(typ, object()) is not None:
+ type_to_key.setdefault(typ, set()).add(state_key)
- if is_all:
- missing_types = set()
- if types is None:
- return state_dict, set(), True
+ if (typ, state_key) not in state_dict:
+ missing_types.add((typ, state_key))
sentinel = object()
def include(typ, state_key):
- if types is None:
- return True
-
valid_state_keys = type_to_key.get(typ, sentinel)
if valid_state_keys is sentinel:
return False
@@ -340,6 +337,19 @@ class StateStore(SQLBaseStore):
if include(k[0], k[1])
}, missing_types, got_all
+ def _get_all_state_from_cache(self, group):
+ """Checks if group is in cache. See `_get_state_for_groups`
+
+ Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
+ indicating if we successfully retrieved all requests state from the
+ cache, if False we need to query the DB for the missing state.
+
+ Args:
+ group: The state group to lookup
+ """
+ is_all, state_dict = self._state_group_cache.get(group)
+ return state_dict, is_all
+
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
"""Given list of groups returns dict of group -> list of state events
@@ -349,18 +359,29 @@ class StateStore(SQLBaseStore):
"""
results = {}
missing_groups_and_types = []
- for group in set(groups):
- state_dict, missing_types, got_all = self._get_state_for_group_from_cache(
- group, types
- )
+ if types is not None:
+ for group in set(groups):
+ state_dict, missing_types, got_all = self._get_some_state_from_cache(
+ group, types
+ )
+
+ results[group] = state_dict
+
+ if not got_all:
+ missing_groups_and_types.append((
+ group,
+ missing_types
+ ))
+ else:
+ for group in set(groups):
+ state_dict, got_all = self._get_all_state_from_cache(
+ group
+ )
- results[group] = state_dict
+ results[group] = state_dict
- if not got_all:
- missing_groups_and_types.append((
- group,
- missing_types if types else None
- ))
+ if not got_all:
+ missing_groups_and_types.append((group, None))
if not missing_groups_and_types:
defer.returnValue({
|