diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index f839c0c24f..e68a409597 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -16,6 +16,8 @@ import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Set, Tuple
+from frozendict import frozendict
+
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -26,6 +28,7 @@ from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.caches.response_cache import ResponseCache
logger = logging.getLogger(__name__)
@@ -91,6 +94,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
500000,
)
+ self._state_group_from_group_cache = ResponseCache(
+ self.hs.get_clock(),
+ # REVIEW: why do the other 2 have asterisks? should this one too?
+ "*stateGroupFromGroupCache*",
+ # TODO: not tuned
+ timeout_ms=30_000,
+ )
+
def get_max_state_group_txn(txn: Cursor):
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
return txn.fetchone()[0]
@@ -156,19 +167,50 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""
results = {}
- chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
- for chunk in chunks:
- res = await self.db_pool.runInteraction(
- "_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn,
- chunk,
- state_filter,
+ for group in groups:
+ results[group] = await self._get_state_groups_from_group(
+ group, state_filter
)
- results.update(res)
return results
- def _get_state_for_group_using_cache(self, cache, group, state_filter):
+ async def _get_state_groups_from_group(
+ self, group: int, state_filter: StateFilter
+ ) -> StateMap[str]:
+ """Returns the state groups for a given group from the
+ database, filtering on types of state events.
+
+ Args:
+ group: state group ID to query
+ state_filter: The state filter used to fetch state
+ from the database.
+ Returns:
+ state map
+ """
+
+ # convert the state_filter.types dict into something that is hashable.
+ frozen_kvs = {}
+ for k, v in state_filter.types.items():
+ if v is None:
+ frozen_kvs[k] = v
+ else:
+ # make the set hashable by making a frozen copy of it
+ frozen_kvs[k] = frozenset(v)
+
+ state_filter_hashable = (frozendict(frozen_kvs), state_filter.include_others)
+
+ return await self._state_group_from_group_cache.wrap(
+ (group, state_filter_hashable),
+ self.db_pool.runInteraction,
+ "_get_state_groups_from_group",
+ self._get_state_groups_from_group_txn,
+ group,
+ state_filter,
+ )
+
+ def _get_state_for_group_using_cache(
+ self, cache: DictionaryCache, group: int, state_filter: StateFilter
+ ) -> Tuple[StateMap, bool]:
"""Checks if group is in cache. See `_get_state_for_groups`
Args:
|