diff options
-rw-r--r-- | synapse/storage/state.py | 85 |
1 files changed, 53 insertions, 32 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 79c3b82d9f..1293842361 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -287,42 +287,39 @@ class StateStore(SQLBaseStore): f, ) - def _get_state_for_group_from_cache(self, group, types=None): + def _get_some_state_from_cache(self, group, types): """Checks if group is in cache. See `_get_state_for_groups` Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). `missing_types` is the list of types that aren't in the cache for that - group, or None if `types` is None. `got_all` is a bool indicating if - we successfully retrieved all requests state from the cache, if False - we need to query the DB for the missing state. + group. `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + + Args: + group: The state group to lookup + types (list): List of 2-tuples of the form (`type`, `state_key`), + where a `state_key` of `None` matches all state_keys for the + `type`. """ is_all, state_dict = self._state_group_cache.get(group) type_to_key = {} missing_types = set() - if types is not None: - for typ, state_key in types: - if state_key is None: - type_to_key[typ] = None - missing_types.add((typ, state_key)) - else: - if type_to_key.get(typ, object()) is not None: - type_to_key.setdefault(typ, set()).add(state_key) - - if (typ, state_key) not in state_dict: - missing_types.add((typ, state_key)) + for typ, state_key in types: + if state_key is None: + type_to_key[typ] = None + missing_types.add((typ, state_key)) + else: + if type_to_key.get(typ, object()) is not None: + type_to_key.setdefault(typ, set()).add(state_key) - if is_all: - missing_types = set() - if types is None: - return state_dict, set(), True + if (typ, state_key) not in state_dict: + missing_types.add((typ, state_key)) sentinel = object() def include(typ, state_key): - if types is None: - return True - valid_state_keys = type_to_key.get(typ, sentinel) if valid_state_keys is sentinel: return False @@ -340,6 +337,19 @@ class StateStore(SQLBaseStore): if include(k[0], k[1]) }, missing_types, got_all + def _get_all_state_from_cache(self, group): + """Checks if group is in cache. See `_get_state_for_groups` + + Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool + indicating if we successfully retrieved all requests state from the + cache, if False we need to query the DB for the missing state. + + Args: + group: The state group to lookup + """ + is_all, state_dict = self._state_group_cache.get(group) + return state_dict, is_all + @defer.inlineCallbacks def _get_state_for_groups(self, groups, types=None): """Given list of groups returns dict of group -> list of state events @@ -349,18 +359,29 @@ class StateStore(SQLBaseStore): """ results = {} missing_groups_and_types = [] - for group in set(groups): - state_dict, missing_types, got_all = self._get_state_for_group_from_cache( - group, types - ) + if types is not None: + for group in set(groups): + state_dict, missing_types, got_all = self._get_some_state_from_cache( + group, types + ) + + results[group] = state_dict + + if not got_all: + missing_groups_and_types.append(( + group, + missing_types + )) + else: + for group in set(groups): + state_dict, got_all = self._get_all_state_from_cache( + group + ) - results[group] = state_dict + results[group] = state_dict - if not got_all: - missing_groups_and_types.append(( - group, - missing_types if types else None - )) + if not got_all: + missing_groups_and_types.append((group, None)) if not missing_groups_and_types: defer.returnValue({ |