diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b452813fbb..c5ff44fef7 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import namedtuple
import logging
+from collections import namedtuple
from six import iteritems, itervalues
from six.moves import range
@@ -23,10 +23,11 @@ from twisted.internet import defer
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import PostgresEngine
-from synapse.util.caches import intern_string, get_cache_factor_for
+from synapse.util.caches import get_cache_factor_for, intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.stringutils import to_ascii
+
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -585,19 +586,24 @@ class StateGroupWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
- """Given list of groups returns dict of group -> list of state events
- with matching types.
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
Args:
- groups(list[int]): list of groups whose state to query
- types(list[str|None, str|None]|None): List of 2-tuples of the form
- (`type`, `state_key`), where a `state_key` of `None` matches all
- state_keys for the `type`. Presence of type of `None` indicates
- that types not in the list should not be filtered out. If None,
- all events are returned.
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ types (None|iterable[(None|str, None|str)]):
+ indicates the state type/keys required. If None, the whole
+ state is fetched and returned.
+
+ Otherwise, each entry should be a `(type, state_key)` tuple to
+ include in the response. A `state_key` of None is a wildcard
+ meaning that we require all state with that type. A `type` of None
+ indicates that types not in the list should not be filtered out.
Returns:
- dict of group -> list of state events
+ Deferred[dict[int, dict[(type, state_key), EventBase]]]
+ a dictionary mapping from state group to state dictionary.
"""
if types:
types = frozenset(types)
@@ -606,7 +612,7 @@ class StateGroupWorkerStore(SQLBaseStore):
if types is not None:
for group in set(groups):
state_dict_ids, _, got_all = self._get_some_state_from_cache(
- group, types
+ group, types,
)
results[group] = state_dict_ids
@@ -627,26 +633,40 @@ class StateGroupWorkerStore(SQLBaseStore):
# Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence
+ # the DictionaryCache knows if it has *all* the state, but
+ # does not know if it has all of the keys of a particular type,
+ # which makes wildcard lookups expensive unless we have a complete
+ # cache. Hence, if we are doing a wildcard lookup, populate the
+ # cache fully so that we can do an efficient lookup next time.
+
+ if types and any(k is None for (t, k) in types):
+ types_to_fetch = None
+ else:
+ types_to_fetch = types
+
group_to_state_dict = yield self._get_state_groups_from_groups(
- missing_groups, types
+ missing_groups, types_to_fetch,
)
- # Now we want to update the cache with all the things we fetched
- # from the database.
for group, group_state_dict in iteritems(group_to_state_dict):
state_dict = results[group]
- state_dict.update(
- ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
- for k, v in iteritems(group_state_dict)
- )
-
+ # update the result, filtering by `types`.
+ if types:
+ for k, v in iteritems(group_state_dict):
+ (typ, _) = k
+ if k in types or (typ, None) in types:
+ state_dict[k] = v
+ else:
+ state_dict.update(group_state_dict)
+
+ # update the cache with all the things we fetched from the
+ # database.
self._state_group_cache.update(
cache_seq_num,
key=group,
- value=state_dict,
- full=(types is None),
- known_absent=types,
+ value=group_state_dict,
+ fetched_keys=types_to_fetch,
)
defer.returnValue(results)
@@ -753,7 +773,6 @@ class StateGroupWorkerStore(SQLBaseStore):
self._state_group_cache.sequence,
key=state_group,
value=dict(current_state_ids),
- full=True,
)
return state_group
|