author | Olivier Wilkinson (reivilibre) <olivier@librepush.net> | 2021-08-23 18:52:56 +0100
---|---|---
committer | Olivier Wilkinson (reivilibre) <olivier@librepush.net> | 2021-08-24 10:43:09 +0100
commit | 42da29594f1d9b84fa93e752adca08a3a8c2eea1 (patch)
tree | eefd0aed406d2103fc29e200d2b8176b772c1d7a
parent | Type annotations (diff)
download | synapse-42da29594f1d9b84fa93e752adca08a3a8c2eea1.tar.xz
checkpoint: mostly what I wanted here
-rw-r--r-- | synapse/storage/databases/state/store.py | 206
-rw-r--r-- | synapse/util/caches/multi_key_response_cache.py | 210
2 files changed, 391 insertions, 25 deletions
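The heart of this checkpoint is an in-flight cache: while a database request for some state groups is outstanding, every key that the request will satisfy points at the same deferred result, so concurrent callers share one query instead of spawning duplicates. The following is a rough sketch of that idea only, not Synapse code: it uses plain asyncio in place of Twisted, and all of the names are invented.

    import asyncio
    from typing import Any, Awaitable, Callable, Dict, Optional, Tuple


    class InflightCacheSketch:
        """Toy model: every key a pending fetch will satisfy maps to one
        shared future, so overlapping concurrent lookups share a query."""

        def __init__(self) -> None:
            self._inflight: Dict[Any, asyncio.Future] = {}

        def get(self, key: Any) -> Optional[asyncio.Future]:
            # A hit means another caller is already fetching this key;
            # the caller can simply await the shared future.
            return self._inflight.get(key)

        def set_and_compute(
            self, keys: Tuple[Any, ...], compute: Callable[[], Awaitable[Any]]
        ) -> asyncio.Future:
            fut = asyncio.ensure_future(compute())
            # First writer wins: never overtake an entry already present.
            mine = [k for k in keys if self._inflight.setdefault(k, fut) is fut]
            # Mirrors the timeout_ms=0 configuration used below: entries are
            # dropped as soon as the shared fetch completes.
            fut.add_done_callback(
                lambda _: [self._inflight.pop(k, None) for k in mine]
            )
            return fut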
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index f839c0c24f..8b73c0a92b 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -14,9 +14,14 @@
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Union
+
+import attr
+
+from twisted.internet.defer import Deferred
 
 from synapse.api.constants import EventTypes
+from synapse.logging.context import make_deferred_yieldable
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
@@ -26,9 +31,12 @@ from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import MutableStateMap, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.caches.multi_key_response_cache import MultiKeyResponseCache
 
 logger = logging.getLogger(__name__)
 
+# XXX
+UNKNOWN = Any  # TODO
 
 MAX_STATE_DELTA_HOPS = 100
@@ -91,6 +99,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             500000,
         )
 
+        # XXX ADD TYPE
+        self._state_group_inflight_cache: "MultiKeyResponseCache[...]" = (
+            MultiKeyResponseCache(
+                self.hs.get_clock(),
+                "*stateGroupInflightCache*",
+                # As the results from this transaction immediately go into the
+                # immediate caches _state_group_cache and _state_group_members_cache,
+                # we do not keep them in the in-flight cache when done.
+                timeout_ms=0,
+            )
+        )
+
         def get_max_state_group_txn(txn: Cursor):
             txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
             return txn.fetchone()[0]
@@ -168,13 +188,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return results
 
-    def _get_state_for_group_using_cache(self, cache, group, state_filter):
+    def _get_state_for_group_using_cache(
+        self,
+        cache: DictionaryCache[int, UNKNOWN],
+        group: int,
+        state_filter: StateFilter,
+    ) -> Tuple[MutableStateMap[UNKNOWN], bool]:
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Args:
-            cache(DictionaryCache): the state group cache to use
-            group(int): The state group to lookup
-            state_filter (StateFilter): The state filter used to fetch state
+            cache: the state group cache to use
+            group: The state group to look up
+            state_filter: The state filter used to fetch state
                 from the database.
 
         Returns 2-tuple (`state_dict`, `got_all`).
@@ -212,7 +237,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
         """Gets the state at each of a list of state groups, optionally
-        filtering by type/state_key
+        filtering by type/state_key.
 
         Args:
             groups: list of state groups for which we want
@@ -221,11 +246,38 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 from the database.
         Returns:
             Dict of state group to state map.
+
+
+        The flow for this function looks as follows:
+
+        * Query the immediate caches (self._state_group_cache,
+        |   self._state_group_members_cache).
+NONSTOP |
+        |
+        * Query the in-flight cache (self._state_group_inflight_cache)
+        |   for immediate-cache misses.
+NONSTOP |
+        |
+        * Service cache misses:
+        |   - Expand the state filter (to help cache hit ratio).
+        |   - Start a new transaction to fetch outstanding groups.
+        |   - Register entries in the in-flight cache for this transaction.
+        |   - (When the transaction is finished) Register entries in
+        |     the immediate caches.
+        |
+        * Wait for in-flight requests to finish...
+        |
+        * Assemble everything together and filter out anything we didn't
+          ask for.
+
+        The sections marked NONSTOP must not contain any `await`s, otherwise
+        race conditions could occur and the cache could be made less effective.
         """
-        state_filter = state_filter or StateFilter.all()
+        state_filter = state_filter or StateFilter.all()
 
         member_filter, non_member_filter = state_filter.get_member_split()
 
+        # QUERY THE IMMEDIATE CACHES
         # Now we look them up in the member and non-member caches
         (
             non_member_state,
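The NONSTOP sections in the docstring above guard against a check-then-act race: on a single-threaded reactor, control only interleaves at `await` points, so probing the in-flight cache and registering a replacement request must happen without yielding in between. A sketch of the failure mode this prevents, again in asyncio terms with hypothetical names:

    import asyncio

    async def racy_lookup(cache, key, fetch):
        # Probe: is another caller already fetching this key?
        pending = cache.get(key)
        if pending is not None:
            return await pending  # piggyback on the in-flight request

        # BAD: yielding to the event loop here lets a concurrent caller run
        # the same probe, also miss, and spawn a second identical database
        # query, which is exactly what the in-flight cache is meant to stop.
        await asyncio.sleep(0)

        return await cache.set_and_compute((key,), fetch)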
@@ -242,43 +294,153 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         for group in groups:
             state[group].update(member_state[group])
 
-        # Now fetch any missing groups from the database
-
         incomplete_groups = incomplete_groups_m | incomplete_groups_nm
 
         if not incomplete_groups:
             return state
 
-        cache_sequence_nm = self._state_group_cache.sequence
-        cache_sequence_m = self._state_group_members_cache.sequence
+        # QUERY THE IN-FLIGHT CACHE
+        # list of (group ID, Deferred that will contain a result for that group)
+        inflight_requests: List[Tuple[int, Deferred[Dict[int, StateMap[str]]]]] = []
+        inflight_cache_misses: List[int] = []
 
-        # Help the cache hit ratio by expanding the filter a bit
+        # When we get around to requesting state from the database, we help the
+        # cache hit ratio by expanding the filter a bit.
+        # However, we need to know the expanded filter now so that we can
+        # properly query the in-flight cache where include_others is concerned.
         db_state_filter = state_filter.return_expanded()
 
-        group_to_state_dict = await self._get_state_groups_from_groups(
-            list(incomplete_groups), state_filter=db_state_filter
-        )
+        for group in incomplete_groups:
+            event_type: str
+            state_keys: Optional[FrozenSet[str]]
+
+            # First check if our exact state filter is being looked up.
+            result = self._state_group_inflight_cache.get((group, db_state_filter))
+            if result is not None:
+                inflight_requests.append((group, make_deferred_yieldable(result)))
+                continue
+
+            # Then check if the universal state filter is being looked up.
+            result = self._state_group_inflight_cache.get((group, StateFilter.all()))
+            if result is not None:
+                inflight_requests.append((group, make_deferred_yieldable(result)))
+                continue
+
+            if state_filter.include_others:
+                # If the state filter includes others, we only match against the
+                # state filter directly, so we give up here.
+                # This is because it's too complex to cache this case properly.
+                inflight_cache_misses.append(group)
+                continue
+            elif not db_state_filter.include_others:
+                # TODO IS THIS USEFUL
+                # Try looking to see if the same filter, but with include_others,
+                # is being looked up.
+                result = self._state_group_inflight_cache.get(
+                    (group, attr.evolve(db_state_filter, include_others=True))
+                )
+                if result is not None:
+                    inflight_requests.append((group, make_deferred_yieldable(result)))
+                    continue
+
+            for event_type, state_keys in state_filter.types.items():
+                result = self._state_group_inflight_cache.get((group, event_type, None))
+                if result is not None:
+                    inflight_requests.append((group, make_deferred_yieldable(result)))
+                    continue
+
+                if state_keys is not None:
+                    got_all_state_keys = False
+                    for state_key in state_keys:
+                        result = self._state_group_inflight_cache.get(
+                            (group, event_type, state_key)
+                        )
+                        if result is not None:
+                            inflight_requests.append(
+                                (group, make_deferred_yieldable(result))
+                            )
+                        else:
+                            break
+                    else:
+                        got_all_state_keys = True
+
+                    if not got_all_state_keys:
+                        # we still have to request against this group.
+                        inflight_cache_misses.append(group)
+                        break
+                else:
+                    # The wildcard key (group, event_type, None) missed and there
+                    # are no individual state_keys to probe, so we still have to
+                    # request against this group.
+                    inflight_cache_misses.append(group)
+                    break
+        # SERVICE CACHE MISSES
+        if inflight_cache_misses:
+            cache_sequence_nm = self._state_group_cache.sequence
+            cache_sequence_m = self._state_group_members_cache.sequence
+
+            async def get_state_groups_from_groups_then_add_to_cache() -> Dict[
+                int, StateMap[str]
+            ]:
+                groups_to_state_dict = await self._get_state_groups_from_groups(
+                    list(inflight_cache_misses), state_filter=db_state_filter
+                )
 
-        # Now lets update the caches
-        self._insert_into_cache(
-            group_to_state_dict,
-            db_state_filter,
-            cache_seq_num_members=cache_sequence_m,
-            cache_seq_num_non_members=cache_sequence_nm,
-        )
+                # Now let's update the caches.
+                self._insert_into_cache(
+                    groups_to_state_dict,
+                    db_state_filter,
+                    cache_seq_num_members=cache_sequence_m,
+                    cache_seq_num_non_members=cache_sequence_nm,
+                )
+                return groups_to_state_dict
+
+            # Make a list of keys for us to store in the in-flight cache.
+            # This should list all the keys that the request will pick up from
+            # the database.
+            keys: List[
+                Union[Tuple[int, StateFilter], Tuple[int, str, Optional[str]]]
+            ] = []
+            for group in inflight_cache_misses:
+                if db_state_filter.include_others:
+                    # we can't intelligently cache include_others under any other
+                    # keys because we don't know what keys are included.
+                    keys.append((group, db_state_filter))
+                    continue
+
+                for event_type, state_keys in db_state_filter.types.items():
+                    if state_keys is None:
+                        keys.append((group, event_type, None))
+                    else:
+                        for state_key in state_keys:
+                            keys.append((group, event_type, state_key))
+
+            spawned_request = self._state_group_inflight_cache.set_and_compute(
+                tuple(keys), get_state_groups_from_groups_then_add_to_cache
+            )
+            for group in inflight_cache_misses:
+                inflight_requests.append((group, spawned_request))
+
+        # WAIT FOR IN-FLIGHT REQUESTS TO FINISH
+        for group, inflight_request in inflight_requests:
+            request_result = await inflight_request
+            state[group].update(request_result[group])
+
+        # ASSEMBLE
         # And finally update the result dict, by filtering out any extra
         # stuff we pulled out of the database.
-        for group, group_state_dict in group_to_state_dict.items():
+        for group in groups:
             # We just replace any existing entries, as we will have loaded
             # everything we need from the database anyway.
-            state[group] = state_filter.filter_state(group_state_dict)
+            state[group] = state_filter.filter_state(state[group])
 
         return state
 
     def _get_state_for_groups_using_cache(
         self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
-    ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+    ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key, querying from a specific cache.
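To make the keying scheme above concrete: each spawned database request registers itself under every key it will satisfy, so later callers can find it through any of them. Illustrative values only; the group ID and event types are made up:

    # One miss for group 42, where db_state_filter.include_others is False and
    # db_state_filter.types is {"m.room.member": {"@alice:hs", "@bob:hs"},
    # "m.room.name": None}. The request fans out to one key per lookup:
    keys = [
        (42, "m.room.member", "@alice:hs"),
        (42, "m.room.member", "@bob:hs"),
        (42, "m.room.name", None),  # None covers all state_keys of the type
    ]

    # When include_others is True, the fetched set cannot be enumerated up
    # front, so the whole (group, filter) pair becomes a single opaque key:
    keys = [(42, db_state_filter)]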
diff --git a/synapse/util/caches/multi_key_response_cache.py b/synapse/util/caches/multi_key_response_cache.py
new file mode 100644
index 0000000000..1e35c33a4b
--- /dev/null
+++ b/synapse/util/caches/multi_key_response_cache.py
@@ -0,0 +1,210 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Any, Awaitable, Callable, Dict, Generic, Optional, Tuple, TypeVar
+
+import attr
+
+from twisted.internet import defer
+
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.util import Clock
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.caches import register_cache
+
+logger = logging.getLogger(__name__)
+
+# the type of the key in the cache
+KV = TypeVar("KV")
+
+# the type of the result from the operation
+RV = TypeVar("RV")
+
+
+@attr.s(auto_attribs=True)
+class MultiKeyResponseCacheContext(Generic[KV]):
+    """Information about a MultiKeyResponseCache miss.
+
+    This object can be passed into the callback for additional feedback.
+    """
+
+    cache_keys: Tuple[KV, ...]
+    """The cache keys that caused the cache miss.
+
+    This should be considered read-only.
+
+    TODO: in attrs 20.1, make it frozen with an on_setattr.
+    """
+
+    should_cache: bool = True
+    """Whether the result should be cached once the request completes.
+
+    This can be modified by the callback if it decides its result should not be cached.
+    """
+
+
+class MultiKeyResponseCache(Generic[KV]):
+    """
+    This caches a deferred response. Until the deferred completes, it will be
+    returned from the cache. This means that if the client retries the request
+    while the response is still being computed, the original response will be
+    used rather than trying to compute a new response.
+
+    Unlike the plain ResponseCache, this cache admits multiple keys to the
+    deferred response.
+    """
+
+    def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
+        # This is poorly-named: it includes both complete and incomplete results.
+        # We keep complete results rather than switching to absolute values because
+        # that makes it easier to cache Failure results.
+        self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
+
+        self.clock = clock
+        self.timeout_sec = timeout_ms / 1000.0
+
+        self._name = name
+        self._metrics = register_cache(
+            "multikey_response_cache", name, self, resizable=False
+        )
+
+    def size(self) -> int:
+        return len(self.pending_result_cache)
+
+    def __len__(self) -> int:
+        return self.size()
+
+    def get(self, key: KV) -> Optional[defer.Deferred]:
+        """Look up the given key.
+
+        Returns a new Deferred (which doesn't follow the synapse logcontext
+        rules). You will probably want to make_deferred_yieldable the result.
+
+        If there is no entry for the key, returns None.
+
+        Args:
+            key: key to look up in the cache
+
+        Returns:
+            None if there is no entry for this key; otherwise a deferred which
+            resolves to the result.
+        """
+        result = self.pending_result_cache.get(key)
+        if result is not None:
+            self._metrics.inc_hits()
+            return result.observe()
+        else:
+            self._metrics.inc_misses()
+            return None
+
+    def _set(
+        self, context: MultiKeyResponseCacheContext[KV], deferred: defer.Deferred
+    ) -> defer.Deferred:
+        """Set the entry for the given keys to the given deferred.
+
+        *deferred* should run its callbacks in the sentinel logcontext (ie,
+        you should wrap normal synapse deferreds with
+        synapse.logging.context.run_in_background).
+
+        Returns a new Deferred (which also doesn't follow the synapse logcontext
+        rules). You will probably want to make_deferred_yieldable the result.
+
+        Args:
+            context: Information about the cache miss
+            deferred: The deferred which resolves to the result.
+
+        Returns:
+            A new deferred which resolves to the actual result.
+        """
+        result = ObservableDeferred(deferred, consumeErrors=True)
+        keys = context.cache_keys
+        for key in keys:
+            if key not in self.pending_result_cache:
+                # we only add the key if it's not already there, since we assume
+                # that we won't overtake prior entries.
+                self.pending_result_cache[key] = result
+
+        def on_complete(r):
+            # if this cache has a non-zero timeout, and the callback has not cleared
+            # the should_cache bit, we leave it in the cache for now and schedule
+            # its removal later.
+            if self.timeout_sec and context.should_cache:
+                for key in keys:
+                    # TODO sketch, should do this in only one call_later.
+                    self.clock.call_later(
+                        self.timeout_sec, self.pending_result_cache.pop, key, None
+                    )
+            else:
+                for key in keys:
+                    # otherwise, remove the result immediately.
+                    self.pending_result_cache.pop(key, None)
+            return r
+
+        # make sure we do this *after* adding the entry to pending_result_cache,
+        # in case the result is already complete (in which case flipping the order
+        # would leave us with a stuck entry in the cache).
+        result.addBoth(on_complete)
+        return result.observe()
+
+    def set_and_compute(
+        self,
+        keys: Tuple[KV, ...],
+        callback: Callable[..., Awaitable[RV]],
+        *args: Any,
+        cache_context: bool = False,
+        **kwargs: Any,
+    ) -> "defer.Deferred[RV]":
+        """Perform a *set* call, taking care of logcontexts.
+
+        Makes a call to *callback(*args, **kwargs)*, which should
+        follow the synapse logcontext rules, and adds the result to the cache.
+
+        Example usage:
+
+            async def handle_request(request):
+                # etc
+                return result
+
+            result = await cache.set_and_compute(
+                (key,),
+                handle_request,
+                request,
+            )
+
+        Args:
+            keys: keys to set in the cache
+
+            callback: function to call
+
+            *args: positional parameters to pass to the callback, if it is used
+
+            cache_context: if set, the callback will be given a `cache_context` kw arg,
+                which will be a MultiKeyResponseCacheContext object.
+
+            **kwargs: named parameters to pass to the callback, if it is used
+
+        Returns:
+            The result of the callback (from the cache, or otherwise)
+        """
+        # TODO sketch: logger.debug(
+        #     "[%s]: no cached result for [%s], calculating new one", self._name, keys
+        # )
+        context = MultiKeyResponseCacheContext(cache_keys=keys)
+        if cache_context:
+            kwargs["cache_context"] = context
+        d = run_in_background(callback, *args, **kwargs)
+        result = self._set(context, d)
+
+        return make_deferred_yieldable(result)
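Putting it together, intended usage of the new cache looks roughly like the sketch below; `hs` stands in for the usual HomeServer object, and the keys and `fetch_from_db` are invented:

    from synapse.logging.context import make_deferred_yieldable
    from synapse.util.caches.multi_key_response_cache import MultiKeyResponseCache


    async def example(hs):
        # timeout_ms=0 keeps entries only while the request is in flight,
        # matching the store.py usage above.
        inflight = MultiKeyResponseCache(
            hs.get_clock(), "exampleCache", timeout_ms=0
        )

        async def fetch_from_db():
            # one transaction covering every key registered below
            return {42: {("m.room.create", ""): "$ev1", ("m.room.name", ""): "$ev2"}}

        # The first caller spawns the request, registering it under both keys;
        # set_and_compute already wraps its result in make_deferred_yieldable.
        result = await inflight.set_and_compute(
            ((42, "m.room.create", ""), (42, "m.room.name", "")),
            fetch_from_db,
        )

        # A caller arriving while that request is outstanding finds it by
        # either key and piggybacks. (In this sequential sketch the request
        # has already finished and the entries were dropped, so the probe
        # returns None.)
        pending = inflight.get((42, "m.room.name", ""))
        if pending is not None:
            result = await make_deferred_yieldable(pending)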