summary refs log tree commit diff
diff options
context:
space:
mode:
authorOlivier Wilkinson (reivilibre) <olivier@librepush.net>2021-09-07 09:06:54 +0100
committerOlivier Wilkinson (reivilibre) <olivier@librepush.net>2021-09-07 09:06:54 +0100
commitfda00e102bfec0f4b995cb4882d890e4a3d8e10f (patch)
tree10e05d837940814f268080553725b01b9668da62
parentIntroduce 'MultiKeyResponseCache' (diff)
downloadsynapse-fda00e102bfec0f4b995cb4882d890e4a3d8e10f.tar.xz
Add a multi-key response cache and search it when querying
-rw-r--r--synapse/storage/databases/state/store.py170
1 files changed, 152 insertions, 18 deletions
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 3e0d7d793e..44eb5999ca 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -16,7 +16,12 @@ import logging
 from collections import namedtuple
 from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Union
 
+import attr
+
+from twisted.internet.defer import Deferred
+
 from synapse.api.constants import EventTypes
+from synapse.logging.context import make_deferred_yieldable
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
@@ -26,9 +31,14 @@ from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import MutableStateMap, StateKey, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.caches.multi_key_response_cache import MultiKeyResponseCache
 
 logger = logging.getLogger(__name__)
 
+InflightStateGroupCacheKey = Union[
+    Tuple[int, StateFilter], Tuple[int, str, Optional[str]]
+]
+
 
 MAX_STATE_DELTA_HOPS = 100
 
@@ -93,6 +103,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             500000,
         )
 
+        self._state_group_inflight_cache: MultiKeyResponseCache[
+            InflightStateGroupCacheKey, Dict[int, StateMap[str]]
+        ] = MultiKeyResponseCache(
+            self.hs.get_clock(),
+            "*stateGroupInflightCache*",
+            # As the results from this transaction immediately go into the
+            # immediate caches _state_group_cache and _state_group_members_cache,
+            # we do not keep them in the in-flight cache when done.
+            timeout_ms=0,
+        )
+
         def get_max_state_group_txn(txn: Cursor):
             txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
             return txn.fetchone()[0]
@@ -229,10 +250,56 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Returns:
             Dict of state group to state map.
         """
-        state_filter = state_filter or StateFilter.all()
 
+        def try_combine_inflight_requests(
+            group: int,
+            state_filter: StateFilter,
+            mut_inflight_requests: "List[Tuple[int, Deferred[Dict[int, StateMap[str]]]]]",
+        ) -> bool:
+            """
+            Tries to collect existing in-flight requests that would give us all
+            the desired state for the given group.
+
+            Returns true if successful, or false if not.
+            If successful, adds more in-flight requests to the `mut_inflight_requests` list.
+            """
+            original_inflight_requests = len(mut_inflight_requests)
+            for event_type, state_keys in state_filter.types.items():
+                # First see if any requests are looking up ALL state keys for this
+                # event type.
+                result = self._state_group_inflight_cache.get((group, event_type, None))
+                if result is not None:
+                    inflight_requests.append((group, make_deferred_yieldable(result)))
+                    continue
+
+                if state_keys is None:
+                    # We want all state keys, but there isn't a request in-flight
+                    # wanting them all, so we have to give up here.
+                    del mut_inflight_requests[original_inflight_requests:]
+                    return False
+                else:
+                    # If we are only interested in certain state keys,
+                    # we can see if other in-flight requests would manage to
+                    # give us all the wanted state keys.
+                    for state_key in state_keys:
+                        result = self._state_group_inflight_cache.get(
+                            (group, event_type, state_key)
+                        )
+                        if result is None:
+                            # There isn't an in-flight request already requesting
+                            # this, so give up here.
+                            del mut_inflight_requests[original_inflight_requests:]
+                            return False
+
+                        inflight_requests.append(
+                            (group, make_deferred_yieldable(result))
+                        )
+            return True
+
+        state_filter = state_filter or StateFilter.all()
         member_filter, non_member_filter = state_filter.get_member_split()
 
+        # QUERY THE IMMEDIATE CACHES
         # Now we look them up in the member and non-member caches
         (
             non_member_state,
@@ -249,37 +316,104 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         for group in groups:
             state[group].update(member_state[group])
 
-        # Now fetch any missing groups from the database
-
         incomplete_groups = incomplete_groups_m | incomplete_groups_nm
 
         if not incomplete_groups:
             return state
 
-        cache_sequence_nm = self._state_group_cache.sequence
-        cache_sequence_m = self._state_group_members_cache.sequence
+        # QUERY THE IN-FLIGHT CACHE
+        # list (group ID -> Deferred that will contain a result for that group)
+        inflight_requests: List[Tuple[int, Deferred[Dict[int, StateMap[str]]]]] = []
+        inflight_cache_misses: List[int] = []
 
-        # Help the cache hit ratio by expanding the filter a bit
+        # When we get around to requesting state from the database, we help the
+        # cache hit ratio by expanding the filter a bit.
+        # However, we need to know this now so that we can properly query the
+        # in-flight cache where include_others is concerned.
         db_state_filter = state_filter.return_expanded()
 
-        group_to_state_dict = await self._get_state_groups_from_groups(
-            list(incomplete_groups), state_filter=db_state_filter
-        )
+        for group in incomplete_groups:
+            event_type: str
+            state_keys: Optional[FrozenSet[str]]
+
+            # First check if our exact state filter is being looked up.
+            result = self._state_group_inflight_cache.get((group, db_state_filter))
+            if result is not None:
+                inflight_requests.append((group, make_deferred_yieldable(result)))
+                continue
+
+            # Then check if the universal state filter is being looked up.
+            result = self._state_group_inflight_cache.get((group, StateFilter.all()))
+            if result is not None:
+                inflight_requests.append((group, make_deferred_yieldable(result)))
+                continue
+
+            if state_filter.include_others:
+                # if the state filter includes others, we only match against the
+                # state filter directly, so we give up here.
+                # This is because it's too complex to cache this case properly.
+                inflight_cache_misses.append(group)
+                continue
+            elif not db_state_filter.include_others:
+                # Try looking to see if the same filter but with include_others
+                # is being looked up.
+                result = self._state_group_inflight_cache.get(
+                    (group, attr.evolve(db_state_filter, include_others=True))
+                )
+                if result is not None:
+                    inflight_requests.append((group, make_deferred_yieldable(result)))
+                    continue
+
+            if try_combine_inflight_requests(
+                group, state_filter, inflight_requests
+            ):
+                # succeeded in finding in-flight requests that could be combined
+                # together to give all the state we need for this group.
+                continue
+
+            inflight_cache_misses.append(group)
+
+        # SERVICE CACHE MISSES
+        if inflight_cache_misses:
+            cache_sequence_nm = self._state_group_cache.sequence
+            cache_sequence_m = self._state_group_members_cache.sequence
+
+            async def get_state_groups_from_groups_then_add_to_cache() -> Dict[
+                int, StateMap[str]
+            ]:
+                groups_to_state_dict = await self._get_state_groups_from_groups(
+                    list(inflight_cache_misses), state_filter=db_state_filter
+                )
 
-        # Now lets update the caches
-        self._insert_into_cache(
-            group_to_state_dict,
-            db_state_filter,
-            cache_seq_num_members=cache_sequence_m,
-            cache_seq_num_non_members=cache_sequence_nm,
-        )
+                # Now let's update the caches.
+                self._insert_into_cache(
+                    groups_to_state_dict,
+                    db_state_filter,
+                    cache_seq_num_members=cache_sequence_m,
+                    cache_seq_num_non_members=cache_sequence_nm,
+                )
+
+                return groups_to_state_dict
 
+            keys = ()  # TODO populate with keys
+            spawned_request = self._state_group_inflight_cache.set_and_compute(
+                tuple(keys), get_state_groups_from_groups_then_add_to_cache
+            )
+            for group in inflight_cache_misses:
+                inflight_requests.append((group, spawned_request))
+
+        # WAIT FOR IN-FLIGHT REQUESTS TO FINISH
+        for group, inflight_request in inflight_requests:
+            request_result = await inflight_request
+            state[group].update(request_result[group])
+
+        # ASSEMBLE
         # And finally update the result dict, by filtering out any extra
         # stuff we pulled out of the database.
-        for group, group_state_dict in group_to_state_dict.items():
+        for group in groups:
             # We just replace any existing entries, as we will have loaded
             # everything we need from the database anyway.
-            state[group] = state_filter.filter_state(group_state_dict)
+            state[group] = state_filter.filter_state(state[group])
 
         return state