author     Olivier Wilkinson (reivilibre) <olivier@librepush.net>  2021-08-23 18:52:56 +0100
committer  Olivier Wilkinson (reivilibre) <olivier@librepush.net>  2021-08-24 10:43:09 +0100
commit     42da29594f1d9b84fa93e752adca08a3a8c2eea1 (patch)
tree       eefd0aed406d2103fc29e200d2b8176b772c1d7a
parent     Type annotations (diff)
download   synapse-42da29594f1d9b84fa93e752adca08a3a8c2eea1.tar.xz
checkpoint: mostly what I wanted here
-rw-r--r--  synapse/storage/databases/state/store.py         | 206
-rw-r--r--  synapse/util/caches/multi_key_response_cache.py  | 210
2 files changed, 391 insertions, 25 deletions
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index f839c0c24f..8b73c0a92b 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -14,9 +14,14 @@
 
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Union
+
+import attr
+
+from twisted.internet.defer import Deferred
 
 from synapse.api.constants import EventTypes
+from synapse.logging.context import make_deferred_yieldable
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
@@ -26,9 +31,12 @@ from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import MutableStateMap, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.caches.multi_key_response_cache import MultiKeyResponseCache
 
 logger = logging.getLogger(__name__)
 
+# XXX: placeholder for state-value types that have not been pinned down yet.
+UNKNOWN = Any  # TODO: replace with a precise type
 
 MAX_STATE_DELTA_HOPS = 100
 
@@ -91,6 +99,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             500000,
         )
 
+        # XXX ADD TYPE
+        self._state_group_inflight_cache: MultiKeyResponseCache[
+            ...
+        ] = MultiKeyResponseCache(
+            self.hs.get_clock(),
+            "*stateGroupInflightCache*",
+            # As the results from this transaction go straight into the
+            # immediate caches (_state_group_cache and _state_group_members_cache),
+            # we do not keep them in the in-flight cache once done.
+            timeout_ms=0,
+        )
+
         def get_max_state_group_txn(txn: Cursor):
             txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
             return txn.fetchone()[0]
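
The `...` type parameter above is still marked `# XXX ADD TYPE`. Judging from the
`keys` list built later in this patch, a plausible annotation would be the
following sketch; the alias name `InflightCacheKey` is hypothetical and not part
of the patch, and it assumes `StateFilter` is importable from
`synapse.storage.state` as elsewhere in Synapse at this time:

    from typing import Optional, Tuple, Union

    from synapse.storage.state import StateFilter

    # Hypothetical alias: an in-flight entry is keyed either by
    # (state group, state filter) or by
    # (state group, event type, optional state key).
    InflightCacheKey = Union[
        Tuple[int, StateFilter],
        Tuple[int, str, Optional[str]],
    ]

    # which would make the attribute:
    #     self._state_group_inflight_cache: MultiKeyResponseCache[
    #         InflightCacheKey
    #     ] = ...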
@@ -168,13 +188,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return results
 
-    def _get_state_for_group_using_cache(self, cache, group, state_filter):
+    def _get_state_for_group_using_cache(
+        self,
+        cache: DictionaryCache[int, UNKNOWN],
+        group: int,
+        state_filter: StateFilter,
+    ) -> Tuple[MutableStateMap[UNKNOWN], bool]:
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Args:
-            cache(DictionaryCache): the state group cache to use
-            group(int): The state group to lookup
-            state_filter (StateFilter): The state filter used to fetch state
+            cache: the state group cache to use
+            group: The state group to lookup
+            state_filter: The state filter used to fetch state
                 from the database.
 
         Returns 2-tuple (`state_dict`, `got_all`).
@@ -212,7 +237,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
         """Gets the state at each of a list of state groups, optionally
-        filtering by type/state_key
+        filtering by type/state_key.
 
         Args:
             groups: list of state groups for which we want
@@ -221,11 +246,38 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 from the database.
         Returns:
             Dict of state group to state map.
+
+
+        The flow for this function looks as follows:
+
+                * Query the immediate caches (self._state_group_cache,
+                |                             self._state_group_members_cache).
+        NONSTOP |
+                |
+                * Query the in-flight cache (self._state_group_inflight_cache)
+                | for immediate-cache misses.
+        NONSTOP |
+                |
+                * Service cache misses:
+                |   - Expand the state filter (to help cache hit ratio).
+                |   - Start a new transaction to fetch outstanding groups.
+                |   - Register entries in the in-flight cache for this transaction.
+                |   - (When the transaction is finished) Register entries in
+                |     the immediate caches.
+                |
+                * Wait for in-flight requests to finish...
+                |
+                * Assemble everything together and filter out anything we didn't
+                  ask for.
+
+        The sections marked NONSTOP must not contain any `await`s, otherwise
+        race conditions could occur and the cache could be made less effective.
         """
-        state_filter = state_filter or StateFilter.all()
 
+        state_filter = state_filter or StateFilter.all()
         member_filter, non_member_filter = state_filter.get_member_split()
 
+        # QUERY THE IMMEDIATE CACHES
         # Now we look them up in the member and non-member caches
         (
             non_member_state,
@@ -242,43 +294,147 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         for group in groups:
             state[group].update(member_state[group])
 
-        # Now fetch any missing groups from the database
-
         incomplete_groups = incomplete_groups_m | incomplete_groups_nm
 
         if not incomplete_groups:
             return state
 
-        cache_sequence_nm = self._state_group_cache.sequence
-        cache_sequence_m = self._state_group_members_cache.sequence
+        # QUERY THE IN-FLIGHT CACHE
+        # list of (group ID, Deferred that will resolve to a result for that group)
+        inflight_requests: List[Tuple[int, Deferred[Dict[int, StateMap[str]]]]] = []
+        inflight_cache_misses: List[int] = []
 
-        # Help the cache hit ratio by expanding the filter a bit
+        # When we get around to requesting state from the database, we help the
+        # cache hit ratio by expanding the filter a bit.
+        # However, we need the expanded filter now so that we can query the
+        # in-flight cache correctly where include_others is concerned.
         db_state_filter = state_filter.return_expanded()
 
-        group_to_state_dict = await self._get_state_groups_from_groups(
-            list(incomplete_groups), state_filter=db_state_filter
-        )
+        for group in incomplete_groups:
+            event_type: str
+            state_keys: Optional[FrozenSet[str]]
+
+            # First check if our exact state filter is being looked up.
+            result = self._state_group_inflight_cache.get((group, db_state_filter))
+            if result is not None:
+                inflight_requests.append((group, make_deferred_yieldable(result)))
+                continue
+
+            # Then check if the universal state filter is being looked up.
+            result = self._state_group_inflight_cache.get((group, StateFilter.all()))
+            if result is not None:
+                inflight_requests.append((group, make_deferred_yieldable(result)))
+                continue
+
+            if state_filter.include_others:
+                # If the state filter includes others, the exact-filter lookups
+                # above are the only matches we can make, so we give up here:
+                # it is too complex to cache this case under the per-type keys.
+                inflight_cache_misses.append(group)
+                continue
+            elif not db_state_filter.include_others:
+                # TODO: is this actually useful?
+                # Try looking to see if the same filter, but with include_others
+                # set, is being looked up.
+                result = self._state_group_inflight_cache.get(
+                    (group, attr.evolve(db_state_filter, include_others=True))
+                )
+                if result is not None:
+                    inflight_requests.append((group, make_deferred_yieldable(result)))
+                    continue
+
+            for event_type, state_keys in state_filter.types.items():
+                result = self._state_group_inflight_cache.get((group, event_type, None))
+                if result is not None:
+                    inflight_requests.append((group, make_deferred_yieldable(result)))
+                    continue
+
+                if state_keys is None:
+                    # We want all state keys for this event type, but no
+                    # in-flight request covers that wholesale, so we have to
+                    # request this group from the database.
+                    inflight_cache_misses.append(group)
+                    break
+
+                got_all_state_keys = False
+                for state_key in state_keys:
+                    result = self._state_group_inflight_cache.get(
+                        (group, event_type, state_key)
+                    )
+                    if result is not None:
+                        inflight_requests.append(
+                            (group, make_deferred_yieldable(result))
+                        )
+                    else:
+                        break
+                else:
+                    got_all_state_keys = True
+
+                if not got_all_state_keys:
+                    # we still have to request against this group.
+                    inflight_cache_misses.append(group)
+                    break
+
+        # SERVICE CACHE MISSES
+        if inflight_cache_misses:
+            cache_sequence_nm = self._state_group_cache.sequence
+            cache_sequence_m = self._state_group_members_cache.sequence
+
+            async def get_state_groups_from_groups_then_add_to_cache() -> Dict[
+                int, StateMap[str]
+            ]:
+                groups_to_state_dict = await self._get_state_groups_from_groups(
+                    list(inflight_cache_misses), state_filter=db_state_filter
+                )
 
-        # Now lets update the caches
-        self._insert_into_cache(
-            group_to_state_dict,
-            db_state_filter,
-            cache_seq_num_members=cache_sequence_m,
-            cache_seq_num_non_members=cache_sequence_nm,
-        )
+                # Now let's update the caches.
+                self._insert_into_cache(
+                    groups_to_state_dict,
+                    db_state_filter,
+                    cache_seq_num_members=cache_sequence_m,
+                    cache_seq_num_non_members=cache_sequence_nm,
+                )
 
+                return groups_to_state_dict
+
+            # Make a list of keys for us to store in the in-flight cache:
+            # this should cover all the keys that the request will pick up
+            # from the database.
+            keys: List[
+                Union[Tuple[int, StateFilter], Tuple[int, str, Optional[str]]]
+            ] = []
+            for group in inflight_cache_misses:
+                if db_state_filter.include_others:
+                    # we can't intelligently cache include_others under any other keys
+                    # because we don't know what keys are included.
+                    keys.append((group, db_state_filter))
+                    continue
+
+                for event_type, state_keys in db_state_filter.types.items():
+                    if state_keys is None:
+                        keys.append((group, event_type, None))
+                    else:
+                        for state_key in state_keys:
+                            keys.append((group, event_type, state_key))
+
+            spawned_request = self._state_group_inflight_cache.set_and_compute(
+                tuple(keys), get_state_groups_from_groups_then_add_to_cache
+            )
+            for group in inflight_cache_misses:
+                inflight_requests.append((group, spawned_request))
+
+        # WAIT FOR IN-FLIGHT REQUESTS TO FINISH
+        for group, inflight_request in inflight_requests:
+            request_result = await inflight_request
+            state[group].update(request_result[group])
+
+        # ASSEMBLE
         # And finally update the result dict, by filtering out any extra
         # stuff we pulled out of the database.
-        for group, group_state_dict in group_to_state_dict.items():
+        for group in groups:
             # We just replace any existing entries, as we will have loaded
             # everything we need from the database anyway.
-            state[group] = state_filter.filter_state(group_state_dict)
+            state[group] = state_filter.filter_state(state[group])
 
         return state
 
     def _get_state_for_groups_using_cache(
         self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
-    ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+    ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key, querying from a specific cache.
 
diff --git a/synapse/util/caches/multi_key_response_cache.py b/synapse/util/caches/multi_key_response_cache.py
new file mode 100644
index 0000000000..1e35c33a4b
--- /dev/null
+++ b/synapse/util/caches/multi_key_response_cache.py
@@ -0,0 +1,210 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Any, Awaitable, Callable, Dict, Generic, Optional, Tuple, TypeVar
+
+import attr
+
+from twisted.internet import defer
+
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.util import Clock
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.caches import register_cache
+
+logger = logging.getLogger(__name__)
+
+# the type of the key in the cache
+KV = TypeVar("KV")
+
+# the type of the result from the operation
+RV = TypeVar("RV")
+
+
+@attr.s(auto_attribs=True)
+class MultiKeyResponseCacheContext(Generic[KV]):
+    """Information about a missed MultiKeyResponseCache hit
+
+    This object can be passed into the callback for additional feedback
+    """
+
+    cache_keys: Tuple[KV, ...]
+    """The cache key that caused the cache miss
+
+    This should be considered read-only.
+
+    TODO: in attrs 20.1, make it frozen with an on_setattr.
+    """
+
+    should_cache: bool = True
+    """Whether the result should be cached once the request completes.
+
+    This can be modified by the callback if it decides its result should not be cached.
+    """
+
+
+class MultiKeyResponseCache(Generic[KV]):
+    """
+    This caches a deferred response. Until the deferred completes, it will be
+    returned from the cache. This means that if the client retries the request
+    while the response is still being computed, the original response will be
+    used rather than a new one being computed.
+
+    Unlike the plain ResponseCache, this cache allows a single deferred
+    response to be registered under multiple keys.
+    """
+
+    def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
+        # This is poorly-named: it includes both complete and incomplete results.
+        # We keep complete results rather than switching to absolute values because
+        # that makes it easier to cache Failure results.
+        self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
+
+        self.clock = clock
+        self.timeout_sec = timeout_ms / 1000.0
+
+        self._name = name
+        self._metrics = register_cache(
+            "multikey_response_cache", name, self, resizable=False
+        )
+
+    def size(self) -> int:
+        return len(self.pending_result_cache)
+
+    def __len__(self) -> int:
+        return self.size()
+
+    def get(self, key: KV) -> Optional[defer.Deferred]:
+        """Look up the given key.
+
+        Returns a new Deferred (which does not follow the synapse logcontext
+        rules). You will probably want to make_deferred_yieldable the result.
+
+        If there is no entry for the key, returns None.
+
+        Args:
+            key: key to get/set in the cache
+
+        Returns:
+            None if there is no entry for this key; otherwise a deferred which
+            resolves to the result.
+        """
+        result = self.pending_result_cache.get(key)
+        if result is not None:
+            self._metrics.inc_hits()
+            return result.observe()
+        else:
+            self._metrics.inc_misses()
+            return None
+
+    def _set(
+        self, context: MultiKeyResponseCacheContext[KV], deferred: defer.Deferred
+    ) -> defer.Deferred:
+        """Set the entry for the given key to the given deferred.
+
+        *deferred* should run its callbacks in the sentinel logcontext (ie,
+        you should wrap normal synapse deferreds with
+        synapse.logging.context.run_in_background).
+
+        Returns a new Deferred (which also doesn't follow the synapse logcontext rules).
+        You will probably want to make_deferred_yieldable the result.
+
+        Args:
+            context: Information about the cache miss
+            deferred: The deferred which resolves to the result.
+
+        Returns:
+            A new deferred which resolves to the actual result.
+        """
+        result = ObservableDeferred(deferred, consumeErrors=True)
+        keys = context.cache_keys
+        for key in keys:
+            if key not in self.pending_result_cache:
+                # we only add the key if it's not already there, since we assume
+                # that we won't overtake prior entries.
+                self.pending_result_cache[key] = result
+
+        def on_complete(r):
+            # if this cache has a non-zero timeout, and the callback has not cleared
+            # the should_cache bit, we leave it in the cache for now and schedule
+            # its removal later.
+            if self.timeout_sec and context.should_cache:
+                for key in keys:
+                    # TODO: this is a sketch; all keys should be expired with
+                    # a single call_later.
+                    self.clock.call_later(
+                        self.timeout_sec, self.pending_result_cache.pop, key, None
+                    )
+            else:
+                for key in keys:
+                    # otherwise, remove the result immediately.
+                    self.pending_result_cache.pop(key, None)
+            return r
+
+        # make sure we do this *after* adding the entry to pending_result_cache,
+        # in case the result is already complete (in which case flipping the order would
+        # leave us with a stuck entry in the cache).
+        result.addBoth(on_complete)
+        return result.observe()
+
+    def set_and_compute(
+        self,
+        keys: Tuple[KV, ...],
+        callback: Callable[..., Awaitable[RV]],
+        *args: Any,
+        cache_context: bool = False,
+        **kwargs: Any,
+    ) -> defer.Deferred[RV]:
+        """Perform a *set* call, taking care of logcontexts
+
+        Makes a call to *callback(*args, **kwargs)*, which should
+        follow the synapse logcontext rules, and adds the result to the cache.
+
+        Example usage:
+
+            async def handle_request(request):
+                # etc
+                return result
+
+            result = await response_cache.wrap(
+                key,
+                handle_request,
+                request,
+            )
+
+        Args:
+            keys: keys to get/set in the cache
+
+            callback: function to call
+
+            *args: positional parameters to pass to the callback, if it is used
+
+            cache_context: if set, the callback will be given a `cache_context` kw arg,
+                which will be a MultiKeyResponseCacheContext object.
+
+            **kwargs: named parameters to pass to the callback, if it is used
+
+        Returns:
+            The result of the callback (from the cache, or otherwise)
+        """
+
+        # TODO: this is a sketch; restore a debug log along these lines:
+        # logger.debug(
+        #     "[%s]: no cached result for [%s], calculating new one", self._name, keys
+        # )
+        context = MultiKeyResponseCacheContext(cache_keys=keys)
+        if cache_context:
+            kwargs["cache_context"] = context
+        d = run_in_background(callback, *args, **kwargs)
+        result = self._set(context, d)
+
+        return make_deferred_yieldable(result)
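
For reference, a minimal usage sketch of the new cache (not part of the patch;
`clock`, `fetch_groups` and the key shapes are invented for illustration). It
shows the intended pattern: register one computation under several keys, then
let concurrent callers observe it via `get`:

    from synapse.logging.context import make_deferred_yieldable
    from synapse.util.caches.multi_key_response_cache import MultiKeyResponseCache

    # `clock` would be the homeserver's Clock, e.g. hs.get_clock().
    cache: "MultiKeyResponseCache" = MultiKeyResponseCache(clock, "example_cache")

    async def fetch_groups() -> dict:
        # A stand-in for an expensive computation covering groups 1 and 2.
        return {1: {"k": "v"}, 2: {"k": "w"}}

    async def get_group(group: int) -> dict:
        # First, see whether a computation covering this key is in flight.
        result = cache.get(("group", group))
        if result is not None:
            # get() returns a deferred that does not follow logcontext rules.
            return (await make_deferred_yieldable(result))[group]

        # Otherwise start one computation, registered under both keys at once.
        d = cache.set_and_compute(
            (("group", 1), ("group", 2)),
            fetch_groups,
        )
        # set_and_compute already wraps its result in make_deferred_yieldable.
        return (await d)[group]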