From 9f87da0a84f93096e228f01f1139c9b5db8ea3d4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 19:43:37 +0100 Subject: Rename Cache->DeferredCache --- tests/util/caches/test_descriptors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests/util') diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 677e925477..bd870b4a33 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -42,9 +42,9 @@ def run_on_reactor(): return make_deferred_yieldable(d) -class CacheTestCase(unittest.TestCase): +class DeferredCacheTestCase(unittest.TestCase): def test_invalidate_all(self): - cache = descriptors.Cache("testcache") + cache = descriptors.DeferredCache("testcache") callback_record = [False, False] -- cgit 1.5.1 From 4182bb812f21d7231ff0efc5e93e5f2e88f6605e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 23:25:23 +0100 Subject: move DeferredCache into its own module --- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/storage/databases/main/client_ips.py | 2 +- synapse/storage/databases/main/devices.py | 3 +- synapse/storage/databases/main/events_worker.py | 3 +- synapse/util/caches/deferred_cache.py | 292 ++++++++++++++++++++++++ synapse/util/caches/descriptors.py | 284 +---------------------- tests/storage/test__base.py | 3 +- tests/test_metrics.py | 2 +- tests/util/caches/test_deferred_cache.py | 64 ++++++ tests/util/caches/test_descriptors.py | 44 ---- 10 files changed, 367 insertions(+), 332 deletions(-) create mode 100644 synapse/util/caches/deferred_cache.py create mode 100644 tests/util/caches/test_deferred_cache.py (limited to 'tests/util') diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 40ea78a353..4b0ea0cc01 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache from ._base import BaseSlavedStore diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index ad32701d36..9e66e6648a 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d903155e89..e662a20d24 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -34,7 +34,8 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder -from synapse.util.caches.descriptors import DeferredCache, cached, cachedList +from synapse.util.caches.deferred_cache import DeferredCache +from 
synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index be7f60f2e8..ff150f0be7 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,8 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.descriptors import DeferredCache, cached +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py new file mode 100644 index 0000000000..f728cd2cf2 --- /dev/null +++ b/synapse/util/caches/deferred_cache.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import threading +from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast + +from prometheus_client import Gauge + +from twisted.internet import defer + +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.caches import register_cache +from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry + +cache_pending_metric = Gauge( + "synapse_util_caches_cache_pending", + "Number of lookups currently pending for this cache", + ["name"], +) + + +KT = TypeVar("KT") +VT = TypeVar("VT") + + +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() + + +class DeferredCache(Generic[KT, VT]): + """Wraps an LruCache, adding support for Deferred results. + + It expects that each entry added with set() will be a Deferred; likewise get() + may return an ObservableDeferred. + """ + + __slots__ = ( + "cache", + "name", + "keylen", + "thread", + "metrics", + "_pending_deferred_cache", + ) + + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key. Ignored unless + `tree` is True. 
+ tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + """ + cache_type = TreeCache if tree else dict + + # _pending_deferred_cache maps from the key value to a `CacheEntry` object. + self._pending_deferred_cache = ( + cache_type() + ) # type: MutableMapping[KT, CacheEntry] + + # cache is used for completed results and maps to the result itself, rather than + # a Deferred. + self.cache = LruCache( + max_size=max_entries, + keylen=keylen, + cache_type=cache_type, + size_callback=(lambda d: len(d)) if iterable else None, + evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, + ) + + self.name = name + self.keylen = keylen + self.thread = None # type: Optional[threading.Thread] + self.metrics = register_cache( + "cache", + name, + self.cache, + collect_callback=self._metrics_collection_callback, + ) + + @property + def max_entries(self): + return self.cache.max_size + + def _on_evicted(self, evicted_count): + self.metrics.inc_evictions(evicted_count) + + def _metrics_collection_callback(self): + cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get( + self, + key: KT, + default=_Sentinel.sentinel, + callback: Optional[Callable[[], None]] = None, + update_metrics: bool = True, + ): + """Looks the key up in the caches. + + Args: + key(tuple) + default: What is returned if key is not in the caches. If not + specified then function throws KeyError instead + callback(fn): Gets called when the entry in the cache is invalidated + update_metrics (bool): whether to update the cache hit rate metrics + + Returns: + Either an ObservableDeferred or the result itself + """ + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if val is not _Sentinel.sentinel: + val.callbacks.update(callbacks) + if update_metrics: + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) + if val is not _Sentinel.sentinel: + self.metrics.inc_hits() + return val + + if update_metrics: + self.metrics.inc_misses() + + if default is _Sentinel.sentinel: + raise KeyError() + else: + return default + + def set( + self, + key: KT, + value: defer.Deferred, + callback: Optional[Callable[[], None]] = None, + ) -> ObservableDeferred: + if not isinstance(value, defer.Deferred): + raise TypeError("not a Deferred") + + callbacks = [callback] if callback else [] + self.check_thread() + observable = ObservableDeferred(value, consumeErrors=True) + observer = observable.observe() + entry = CacheEntry(deferred=observable, callbacks=callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def compare_and_pop(): + """Check if our entry is still the one in _pending_deferred_cache, and + if so, pop it. + + Returns true if the entries matched. 
+ """ + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + return True + + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. (We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + return False + + def cb(result): + if compare_and_pop(): + self.cache.set(key, result, entry.callbacks) + else: + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. + entry.invalidate() + + def eb(_fail): + compare_and_pop() + entry.invalidate() + + # once the deferred completes, we can move the entry from the + # _pending_deferred_cache to the real cache. + # + observer.addCallbacks(cb, eb) + return observable + + def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) + + def invalidate(self, key): + self.check_thread() + self.cache.pop(key, None) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. + entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. + if entry: + entry.invalidate() + + def invalidate_many(self, key: KT): + self.check_thread() + if not isinstance(key, tuple): + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + self.cache.del_multi(key) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above + entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) + if entry_dict is not None: + for entry in iterate_tree_cache_entry(entry_dict): + entry.invalidate() + + def invalidate_all(self): + self.check_thread() + self.cache.clear() + for entry in self._pending_deferred_cache.values(): + entry.invalidate() + self._pending_deferred_cache.clear() + + +class CacheEntry: + __slots__ = ["deferred", "callbacks", "invalidated"] + + def __init__( + self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] + ): + self.deferred = deferred + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 7c9fe199bf..1f43886804 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,44 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import enum import functools import inspect import logging -import threading -from typing import ( - Any, - Callable, - Generic, - Iterable, - MutableMapping, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from weakref import WeakValueDictionary -from prometheus_client import Gauge - from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry - -from . import register_cache +from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] F = TypeVar("F", bound=Callable[..., Any]) -KT = TypeVar("KT") -VT = TypeVar("VT") class _CachedFunction(Generic[F]): @@ -68,266 +48,6 @@ class _CachedFunction(Generic[F]): __call__ = None # type: F -cache_pending_metric = Gauge( - "synapse_util_caches_cache_pending", - "Number of lookups currently pending for this cache", - ["name"], -) - - -class _Sentinel(enum.Enum): - # defining a sentinel in this way allows mypy to correctly handle the - # type of a dictionary lookup. - sentinel = object() - - -class CacheEntry: - __slots__ = ["deferred", "callbacks", "invalidated"] - - def __init__( - self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] - ): - self.deferred = deferred - self.callbacks = set(callbacks) - self.invalidated = False - - def invalidate(self): - if not self.invalidated: - self.invalidated = True - for callback in self.callbacks: - callback() - self.callbacks.clear() - - -class DeferredCache(Generic[KT, VT]): - """Wraps an LruCache, adding support for Deferred results. - - It expects that each entry added with set() will be a Deferred; likewise get() - may return an ObservableDeferred. - """ - - __slots__ = ( - "cache", - "name", - "keylen", - "thread", - "metrics", - "_pending_deferred_cache", - ) - - def __init__( - self, - name: str, - max_entries: int = 1000, - keylen: int = 1, - tree: bool = False, - iterable: bool = False, - apply_cache_factor_from_config: bool = True, - ): - """ - Args: - name: The name of the cache - max_entries: Maximum amount of entries that the cache will hold - keylen: The length of the tuple used as the cache key. Ignored unless - `tree` is True. - tree: Use a TreeCache instead of a dict as the underlying cache type - iterable: If True, count each item in the cached object as an entry, - rather than each cached object - apply_cache_factor_from_config: Whether cache factors specified in the - config file affect `max_entries` - """ - cache_type = TreeCache if tree else dict - - # _pending_deferred_cache maps from the key value to a `CacheEntry` object. - self._pending_deferred_cache = ( - cache_type() - ) # type: MutableMapping[KT, CacheEntry] - - # cache is used for completed results and maps to the result itself, rather than - # a Deferred. 
- self.cache = LruCache( - max_size=max_entries, - keylen=keylen, - cache_type=cache_type, - size_callback=(lambda d: len(d)) if iterable else None, - evicted_callback=self._on_evicted, - apply_cache_factor_from_config=apply_cache_factor_from_config, - ) - - self.name = name - self.keylen = keylen - self.thread = None # type: Optional[threading.Thread] - self.metrics = register_cache( - "cache", - name, - self.cache, - collect_callback=self._metrics_collection_callback, - ) - - @property - def max_entries(self): - return self.cache.max_size - - def _on_evicted(self, evicted_count): - self.metrics.inc_evictions(evicted_count) - - def _metrics_collection_callback(self): - cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get( - self, - key: KT, - default=_Sentinel.sentinel, - callback: Optional[Callable[[], None]] = None, - update_metrics: bool = True, - ): - """Looks the key up in the caches. - - Args: - key(tuple) - default: What is returned if key is not in the caches. If not - specified then function throws KeyError instead - callback(fn): Gets called when the entry in the cache is invalidated - update_metrics (bool): whether to update the cache hit rate metrics - - Returns: - Either an ObservableDeferred or the result itself - """ - callbacks = [callback] if callback else [] - val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) - if val is not _Sentinel.sentinel: - val.callbacks.update(callbacks) - if update_metrics: - self.metrics.inc_hits() - return val.deferred - - val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) - if val is not _Sentinel.sentinel: - self.metrics.inc_hits() - return val - - if update_metrics: - self.metrics.inc_misses() - - if default is _Sentinel.sentinel: - raise KeyError() - else: - return default - - def set( - self, - key: KT, - value: defer.Deferred, - callback: Optional[Callable[[], None]] = None, - ) -> ObservableDeferred: - if not isinstance(value, defer.Deferred): - raise TypeError("not a Deferred") - - callbacks = [callback] if callback else [] - self.check_thread() - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) - - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry: - existing_entry.invalidate() - - self._pending_deferred_cache[key] = entry - - def compare_and_pop(): - """Check if our entry is still the one in _pending_deferred_cache, and - if so, pop it. - - Returns true if the entries matched. - """ - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - return True - - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. 
(We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - - return False - - def cb(result): - if compare_and_pop(): - self.cache.set(key, result, entry.callbacks) - else: - # we're not going to put this entry into the cache, so need - # to make sure that the invalidation callbacks are called. - # That was probably done when _pending_deferred_cache was - # updated, but it's possible that `set` was called without - # `invalidate` being previously called, in which case it may - # not have been. Either way, let's double-check now. - entry.invalidate() - - def eb(_fail): - compare_and_pop() - entry.invalidate() - - # once the deferred completes, we can move the entry from the - # _pending_deferred_cache to the real cache. - # - observer.addCallbacks(cb, eb) - return observable - - def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): - callbacks = [callback] if callback else [] - self.cache.set(key, value, callbacks=callbacks) - - def invalidate(self, key): - self.check_thread() - self.cache.pop(key, None) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, which will (a) stop it being returned - # for future queries and (b) stop it being persisted as a proper entry - # in self.cache. - entry = self._pending_deferred_cache.pop(key, None) - - # run the invalidation callbacks now, rather than waiting for the - # deferred to resolve. - if entry: - entry.invalidate() - - def invalidate_many(self, key: KT): - self.check_thread() - if not isinstance(key, tuple): - raise TypeError("The cache key must be a tuple not %r" % (type(key),)) - self.cache.del_multi(key) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) - if entry_dict is not None: - for entry in iterate_tree_cache_entry(entry_dict): - entry.invalidate() - - def invalidate_all(self): - self.check_thread() - self.cache.clear() - for entry in self._pending_deferred_cache.values(): - entry.invalidate() - self._pending_deferred_cache.clear() - - class _CacheDescriptorBase: def __init__(self, orig: _CachedFunction, num_args, cache_context=False): self.orig = orig diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 00adcab7b9..2598dbe0a7 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,7 +20,8 @@ from mock import Mock from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import DeferredCache, cached +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached from tests import unittest diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 1c03a52f7c..759e4cd048 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,7 +15,7 @@ # limitations under the License. 
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache from tests import unittest diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py new file mode 100644 index 0000000000..9b6acdfc43 --- /dev/null +++ b/tests/util/caches/test_deferred_cache.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +from twisted.internet import defer + +import synapse.util.caches.deferred_cache + + +class DeferredCacheTestCase(unittest.TestCase): + def test_invalidate_all(self): + cache = synapse.util.caches.deferred_cache.DeferredCache("testcache") + + callback_record = [False, False] + + def record_callback(idx): + callback_record[idx] = True + + # add a couple of pending entries + d1 = defer.Deferred() + cache.set("key1", d1, partial(record_callback, 0)) + + d2 = defer.Deferred() + cache.set("key2", d2, partial(record_callback, 1)) + + # lookup should return observable deferreds + self.assertFalse(cache.get("key1").has_called()) + self.assertFalse(cache.get("key2").has_called()) + + # let one of the lookups complete + d2.callback("result2") + + # for now at least, the cache will return real results rather than an + # observabledeferred + self.assertEqual(cache.get("key2"), "result2") + + # now do the invalidation + cache.invalidate_all() + + # lookup should return none + self.assertIsNone(cache.get("key1", None)) + self.assertIsNone(cache.get("key2", None)) + + # both callbacks should have been callbacked + self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") + self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") + + # letting the other lookup complete should do nothing + d1.callback("result1") + self.assertIsNone(cache.get("key1", None)) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index bd870b4a33..3d1f960869 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from functools import partial import mock @@ -42,49 +41,6 @@ def run_on_reactor(): return make_deferred_yieldable(d) -class DeferredCacheTestCase(unittest.TestCase): - def test_invalidate_all(self): - cache = descriptors.DeferredCache("testcache") - - callback_record = [False, False] - - def record_callback(idx): - callback_record[idx] = True - - # add a couple of pending entries - d1 = defer.Deferred() - cache.set("key1", d1, partial(record_callback, 0)) - - d2 = defer.Deferred() - cache.set("key2", d2, partial(record_callback, 1)) - - # lookup should return observable deferreds - self.assertFalse(cache.get("key1").has_called()) - self.assertFalse(cache.get("key2").has_called()) - - # let one of the lookups complete - d2.callback("result2") - - # for now at least, the cache will return real results rather than an - # observabledeferred - self.assertEqual(cache.get("key2"), "result2") - - # now do the invalidation - cache.invalidate_all() - - # lookup should return none - self.assertIsNone(cache.get("key1", None)) - self.assertIsNone(cache.get("key2", None)) - - # both callbacks should have been callbacked - self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") - self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") - - # letting the other lookup complete should do nothing - d1.callback("result1") - self.assertIsNone(cache.get("key1", None)) - - class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): -- cgit 1.5.1 From 470dedd2662536c309407d05085d04a7d61c5de8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 23:37:23 +0100 Subject: Combine the two sets of DeferredCache tests --- tests/storage/test__base.py | 72 ----------------------------- tests/util/caches/test_deferred_cache.py | 77 +++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 74 deletions(-) (limited to 'tests/util') diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 2598dbe0a7..8e69b1e9cc 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,83 +20,11 @@ from mock import Mock from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.descriptors import cached from tests import unittest -class DeferredCacheTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): - self.cache = DeferredCache("test") - - def test_empty(self): - failed = False - try: - self.cache.get("foo") - except KeyError: - failed = True - - self.assertTrue(failed) - - def test_hit(self): - self.cache.prefill("foo", 123) - - self.assertEquals(self.cache.get("foo"), 123) - - def test_invalidate(self): - self.cache.prefill(("foo",), 123) - self.cache.invalidate(("foo",)) - - failed = False - try: - self.cache.get(("foo",)) - except KeyError: - failed = True - - self.assertTrue(failed) - - def test_eviction(self): - cache = DeferredCache("test", max_entries=2) - - cache.prefill(1, "one") - cache.prefill(2, "two") - cache.prefill(3, "three") # 1 will be evicted - - failed = False - try: - cache.get(1) - except KeyError: - failed = True - - self.assertTrue(failed) - - cache.get(2) - cache.get(3) - - def test_eviction_lru(self): - cache = DeferredCache("test", max_entries=2) - - cache.prefill(1, "one") - cache.prefill(2, "two") - - # Now access 1 again, thus causing 2 to be least-recently used - cache.get(1) - - 
cache.prefill(3, "three") - - failed = False - try: - cache.get(2) - except KeyError: - failed = True - - self.assertTrue(failed) - - cache.get(1) - cache.get(3) - - class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 9b6acdfc43..9717be56b6 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -18,12 +18,41 @@ from functools import partial from twisted.internet import defer -import synapse.util.caches.deferred_cache +from synapse.util.caches.deferred_cache import DeferredCache class DeferredCacheTestCase(unittest.TestCase): + def test_empty(self): + cache = DeferredCache("test") + failed = False + try: + cache.get("foo") + except KeyError: + failed = True + + self.assertTrue(failed) + + def test_hit(self): + cache = DeferredCache("test") + cache.prefill("foo", 123) + + self.assertEquals(cache.get("foo"), 123) + + def test_invalidate(self): + cache = DeferredCache("test") + cache.prefill(("foo",), 123) + cache.invalidate(("foo",)) + + failed = False + try: + cache.get(("foo",)) + except KeyError: + failed = True + + self.assertTrue(failed) + def test_invalidate_all(self): - cache = synapse.util.caches.deferred_cache.DeferredCache("testcache") + cache = DeferredCache("testcache") callback_record = [False, False] @@ -62,3 +91,47 @@ class DeferredCacheTestCase(unittest.TestCase): # letting the other lookup complete should do nothing d1.callback("result1") self.assertIsNone(cache.get("key1", None)) + + def test_eviction(self): + cache = DeferredCache( + "test", max_entries=2, apply_cache_factor_from_config=False + ) + + cache.prefill(1, "one") + cache.prefill(2, "two") + cache.prefill(3, "three") # 1 will be evicted + + failed = False + try: + cache.get(1) + except KeyError: + failed = True + + self.assertTrue(failed) + + cache.get(2) + cache.get(3) + + def test_eviction_lru(self): + cache = DeferredCache( + "test", max_entries=2, apply_cache_factor_from_config=False + ) + + cache.prefill(1, "one") + cache.prefill(2, "two") + + # Now access 1 again, thus causing 2 to be least-recently used + cache.get(1) + + cache.prefill(3, "three") + + failed = False + try: + cache.get(2) + except KeyError: + failed = True + + self.assertTrue(failed) + + cache.get(1) + cache.get(3) -- cgit 1.5.1 From 3ee17585cd095e590096683395cfb9a017eac15e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 16 Oct 2020 15:51:57 +0100 Subject: Make LruCache register its own metrics (#8561) rather than have everything that instantiates an LruCache manage metrics separately, have LruCache do it itself. --- changelog.d/8561.misc | 1 + synapse/api/auth.py | 4 +-- synapse/push/push_rule_evaluator.py | 4 +-- synapse/util/caches/__init__.py | 13 ++++++---- synapse/util/caches/deferred_cache.py | 43 ++++++++++-------------------- synapse/util/caches/dictionary_cache.py | 9 +------ synapse/util/caches/lrucache.py | 46 +++++++++++++++++++++++++-------- tests/util/test_lrucache.py | 4 +-- 8 files changed, 62 insertions(+), 62 deletions(-) create mode 100644 changelog.d/8561.misc (limited to 'tests/util') diff --git a/changelog.d/8561.misc b/changelog.d/8561.misc new file mode 100644 index 0000000000..a40dedfa8e --- /dev/null +++ b/changelog.d/8561.misc @@ -0,0 +1 @@ +Move metric registration code down into `LruCache`. 
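To illustrate the new pattern this commit introduces, here is a minimal before/after sketch based on the synapse/api/auth.py hunk below (the surrounding homeserver setup is omitted, so treat it as a sketch rather than a drop-in snippet):

    from synapse.util.caches.lrucache import LruCache

    # Before this commit, every call site paired the constructor with a
    # separate register_cache() call from synapse.util.caches:
    #
    #     token_cache = LruCache(10000)
    #     register_cache("cache", "token_cache", token_cache)

    # After it, passing a cache name makes LruCache register its own
    # prometheus metrics; a cache constructed without a name reports none.
    token_cache = LruCache(10000, "token_cache")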
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1071a0576e..eb6f418b13 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -34,7 +34,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging import opentracing as opentracing from synapse.types import StateMap, UserID -from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -70,8 +69,7 @@ class Auth: self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache(10000) - register_cache("cache", "token_cache", self.token_cache) + self.token_cache = LruCache(10000, "token_cache") self._auth_blocking = AuthBlocking(self.hs) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 3a68ce636f..4c95b149c5 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -20,7 +20,6 @@ from typing import Any, Dict, List, Optional, Pattern, Union from synapse.events import EventBase from synapse.types import UserID -from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -186,8 +185,7 @@ class PushRuleEvaluatorForEvent: # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache = LruCache(50000) -register_cache("cache", "regex_push_cache", regex_cache) +regex_cache = LruCache(50000, "regex_push_cache") def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 8fc05be278..89f0b38535 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -16,7 +16,7 @@ import logging from sys import intern -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Sized import attr from prometheus_client.core import Gauge @@ -92,7 +92,7 @@ class CacheMetric: def register_cache( cache_type: str, cache_name: str, - cache, + cache: Sized, collect_callback: Optional[Callable] = None, resizable: bool = True, resize_callback: Optional[Callable] = None, @@ -100,12 +100,15 @@ def register_cache( """Register a cache object for metric collection and resizing. Args: - cache_type + cache_type: a string indicating the "type" of the cache. This is used + only for deduplication so isn't too important provided it's constant. cache_name: name of the cache - cache: cache itself + cache: cache itself, which must implement __len__(), and may optionally implement + a max_size property collect_callback: If given, a function which is called during metric collection to update additional metrics. - resizable: Whether this cache supports being resized. + resizable: Whether this cache supports being resized, in which case either + resize_callback must be provided, or the cache must support set_max_size(). resize_callback: A function which can be called to resize the cache. 
Returns: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index f728cd2cf2..91fdc8142d 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -24,7 +24,6 @@ from prometheus_client import Gauge from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -54,10 +53,7 @@ class DeferredCache(Generic[KT, VT]): __slots__ = ( "cache", - "name", - "keylen", "thread", - "metrics", "_pending_deferred_cache", ) @@ -89,37 +85,27 @@ class DeferredCache(Generic[KT, VT]): cache_type() ) # type: MutableMapping[KT, CacheEntry] + def metrics_cb(): + cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) + # cache is used for completed results and maps to the result itself, rather than # a Deferred. self.cache = LruCache( max_size=max_entries, keylen=keylen, + cache_name=name, cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, - evicted_callback=self._on_evicted, + metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, ) - self.name = name - self.keylen = keylen self.thread = None # type: Optional[threading.Thread] - self.metrics = register_cache( - "cache", - name, - self.cache, - collect_callback=self._metrics_collection_callback, - ) @property def max_entries(self): return self.cache.max_size - def _on_evicted(self, evicted_count): - self.metrics.inc_evictions(evicted_count) - - def _metrics_collection_callback(self): - cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) - def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -154,21 +140,18 @@ class DeferredCache(Generic[KT, VT]): if val is not _Sentinel.sentinel: val.callbacks.update(callbacks) if update_metrics: - self.metrics.inc_hits() + m = self.cache.metrics + assert m # we always have a name, so should always have metrics + m.inc_hits() return val.deferred - val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) - if val is not _Sentinel.sentinel: - self.metrics.inc_hits() - return val - - if update_metrics: - self.metrics.inc_misses() - - if default is _Sentinel.sentinel: + val = self.cache.get( + key, default, callbacks=callbacks, update_metrics=update_metrics + ) + if val is _Sentinel.sentinel: raise KeyError() else: - return default + return val def set( self, diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 8592b93689..8b426c005b 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -19,8 +19,6 @@ from collections import namedtuple from synapse.util.caches.lrucache import LruCache -from . 
import register_cache - logger = logging.getLogger(__name__) @@ -46,18 +44,16 @@ class DictionaryCache: """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries, size_callback=len) + self.cache = LruCache(max_size=max_entries, cache_name=name, size_callback=len) self.name = name self.sequence = 0 self.thread = None - # caches_by_name[name] = self.cache class Sentinel: __slots__ = [] self.sentinel = Sentinel() - self.metrics = register_cache("dictionary", name, self.cache) def check_thread(self): expected_thread = self.thread @@ -82,8 +78,6 @@ class DictionaryCache: """ entry = self.cache.get(key, self.sentinel) if entry is not self.sentinel: - self.metrics.inc_hits() - if dict_keys is None: return DictionaryEntry( entry.full, entry.known_absent, dict(entry.value) @@ -95,7 +89,6 @@ class DictionaryCache: {k: entry.value[k] for k in dict_keys if k in entry.value}, ) - self.metrics.inc_misses() return DictionaryEntry(False, set(), {}) def invalidate(self, key): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 33eae2b7c4..e4804f79e0 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -18,6 +18,7 @@ from functools import wraps from typing import Callable, Optional, Type, Union from synapse.config import cache as cache_config +from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.treecache import TreeCache @@ -43,27 +44,29 @@ class _Node: class LruCache: """ - Least-recently-used cache. + Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. + Supports del_multi only if cache_type=TreeCache If cache_type=TreeCache, all keys must be tuples. - - Can also set callbacks on objects when getting/setting which are fired - when that key gets invalidated/evicted. """ def __init__( self, max_size: int, + cache_name: Optional[str] = None, keylen: int = 1, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable] = None, - evicted_callback: Optional[Callable] = None, + metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, ): """ Args: max_size: The maximum amount of entries the cache can hold + cache_name: The name of this cache, for the prometheus metrics. If unset, + no metrics will be reported on this cache. + keylen: The length of the tuple used as the cache key. Ignored unless cache_type is `TreeCache`. @@ -73,9 +76,13 @@ class LruCache: size_callback (func(V) -> int | None): - evicted_callback (func(int)|None): - if not None, called on eviction with the size of the evicted - entry + metrics_collection_callback: + metrics collection callback. This is called early in the metrics + collection process, before any of the metrics registered with the + prometheus Registry are collected, so can be used to update any dynamic + metrics. + + Ignored if cache_name is None. 
apply_cache_factor_from_config (bool): If true, `max_size` will be multiplied by a cache factor derived from the homeserver config @@ -94,6 +101,19 @@ class LruCache: else: self.max_size = int(max_size) + if cache_name is not None: + metrics = register_cache( + "lru_cache", + cache_name, + self, + collect_callback=metrics_collection_callback, + ) # type: Optional[CacheMetric] + else: + metrics = None + + # this is exposed for access from outside this class + self.metrics = metrics + list_root = _Node(None, None, None, None) list_root.next_node = list_root list_root.prev_node = list_root @@ -105,8 +125,8 @@ class LruCache: todelete = list_root.prev_node evicted_len = delete_node(todelete) cache.pop(todelete.key, None) - if evicted_callback: - evicted_callback(evicted_len) + if metrics: + metrics.inc_evictions(evicted_len) def synchronized(f): @wraps(f) @@ -169,13 +189,17 @@ class LruCache: return deleted_len @synchronized - def cache_get(key, default=None, callbacks=[]): + def cache_get(key, default=None, callbacks=[], update_metrics=True): node = cache.get(key, None) if node is not None: move_node_to_front(node) node.callbacks.update(callbacks) + if update_metrics and metrics: + metrics.inc_hits() return node.value else: + if update_metrics and metrics: + metrics.inc_misses() return default @synchronized diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 0adb2174af..f12834edab 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -59,7 +59,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEquals(cache.pop("key"), None) def test_del_multi(self): - cache = LruCache(4, 2, cache_type=TreeCache) + cache = LruCache(4, keylen=2, cache_type=TreeCache) cache[("animal", "cat")] = "mew" cache[("animal", "dog")] = "woof" cache[("vehicles", "car")] = "vroom" @@ -160,7 +160,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): m2 = Mock() m3 = Mock() m4 = Mock() - cache = LruCache(4, 2, cache_type=TreeCache) + cache = LruCache(4, keylen=2, cache_type=TreeCache) cache.set(("a", "1"), "value", callbacks=[m1]) cache.set(("a", "2"), "value", callbacks=[m2]) -- cgit 1.5.1 From 903d11c43a5df9f704e5dad4d14506a6470524fc Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 19 Oct 2020 15:00:12 +0100 Subject: Add `DeferredCache.get_immediate` method (#8568) * Add `DeferredCache.get_immediate` method A bunch of things that are currently calling `DeferredCache.get` are only really interested in the result if it's completed. We can optimise and simplify this case. * Remove unused 'default' parameter to DeferredCache.get() * another get_immediate instance --- changelog.d/8568.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/storage/databases/main/pusher.py | 2 +- synapse/storage/databases/main/receipts.py | 11 +-------- synapse/storage/databases/main/roommember.py | 2 +- synapse/util/caches/deferred_cache.py | 35 ++++++++++++++++++++-------- tests/util/caches/test_deferred_cache.py | 27 +++++++++++++++++---- 7 files changed, 53 insertions(+), 27 deletions(-) create mode 100644 changelog.d/8568.misc (limited to 'tests/util') diff --git a/changelog.d/8568.misc b/changelog.d/8568.misc new file mode 100644 index 0000000000..0ed7db92d3 --- /dev/null +++ b/changelog.d/8568.misc @@ -0,0 +1 @@ +Add `get_immediate` method to `DeferredCache`. 
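To make the new API concrete, here is a small sketch of get_immediate()'s behaviour, mirroring the unit test added below (the key and values are illustrative):

    from twisted.internet import defer

    from synapse.util.caches.deferred_cache import DeferredCache

    cache = DeferredCache("test")
    d1 = defer.Deferred()
    cache.set("key1", d1)

    # While the lookup is still pending, the supplied default comes back...
    assert cache.get_immediate("key1", 1) == 1

    d1.callback(2)

    # ...and once the Deferred has completed, the cached result does.
    assert cache.get_immediate("key1", 1) == 2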
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index c440f2545c..a701defcdd 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -496,6 +496,6 @@ class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))): # dedupe when we add callbacks to lru cache nodes, otherwise the number # of callbacks would grow. def __call__(self): - rules = self.cache.get(self.room_id, None, update_metrics=False) + rules = self.cache.get_immediate(self.room_id, None, update_metrics=False) if rules: rules.invalidate_all() diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index df8609b97b..7997242d90 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore): lock=False, ) - user_has_pusher = self.get_if_user_has_pusher.cache.get( + user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate( (user_id,), None, update_metrics=False ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 5cdf16521c..ca7917c989 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -25,7 +25,6 @@ from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -413,18 +412,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): if receipt_type != "m.read": return - # Returns either an ObservableDeferred or the raw result - res = self.get_users_with_read_receipts_in_room.cache.get( + res = self.get_users_with_read_receipts_in_room.cache.get_immediate( room_id, None, update_metrics=False ) - # first handle the ObservableDeferred case - if isinstance(res, ObservableDeferred): - if res.has_called(): - res = res.get_result() - else: - res = None - if res and user_id in res: # We'd only be adding to the set, so no point invalidating if the # user is already there diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 20fcdaa529..9b08b49862 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -531,7 +531,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # If we do then we can reuse that result and simply update it with # any membership changes in `delta_ids` if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get( + prev_res = self._get_joined_users_from_context.cache.get_immediate( (room_id, context.prev_group), None ) if prev_res and isinstance(prev_res, dict): diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 4026e1f8fa..faeef75506 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -17,7 +17,16 @@ import enum import threading -from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast +from typing import ( + Callable, + Generic, + Iterable, + MutableMapping, + Optional, + TypeVar, + Union, + cast, +) from prometheus_client import Gauge @@ -33,7 +42,7 @@ 
cache_pending_metric = Gauge(
     ["name"],
 )
 
-
+T = TypeVar("T")
 KT = TypeVar("KT")
 VT = TypeVar("VT")
@@ -119,21 +128,21 @@ class DeferredCache(Generic[KT, VT]):
     def get(
         self,
         key: KT,
-        default=_Sentinel.sentinel,
         callback: Optional[Callable[[], None]] = None,
         update_metrics: bool = True,
-    ):
+    ) -> Union[ObservableDeferred, VT]:
         """Looks the key up in the caches.
 
         Args:
             key(tuple)
-            default: What is returned if key is not in the caches. If not
-                specified then function throws KeyError instead
             callback(fn): Gets called when the entry in the cache is invalidated
             update_metrics (bool): whether to update the cache hit rate metrics
 
         Returns:
             Either an ObservableDeferred or the result itself
+
+        Raises:
+            KeyError if the key is not found in the cache
         """
         callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
@@ -145,13 +154,19 @@ class DeferredCache(Generic[KT, VT]):
             m.inc_hits()
             return val.deferred
 
-        val = self.cache.get(
-            key, default, callbacks=callbacks, update_metrics=update_metrics
+        val2 = self.cache.get(
+            key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
         )
-        if val is _Sentinel.sentinel:
+        if val2 is _Sentinel.sentinel:
             raise KeyError()
         else:
-            return val
+            return val2
+
+    def get_immediate(
+        self, key: KT, default: T, update_metrics: bool = True
+    ) -> Union[VT, T]:
+        """If we have a *completed* cached value, return it."""
+        return self.cache.get(key, default, update_metrics=update_metrics)
 
     def set(
         self,
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 9717be56b6..8a08ab6661 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -38,6 +38,22 @@ class DeferredCacheTestCase(unittest.TestCase):
 
         self.assertEquals(cache.get("foo"), 123)
 
+    def test_get_immediate(self):
+        cache = DeferredCache("test")
+        d1 = defer.Deferred()
+        cache.set("key1", d1)
+
+        # get_immediate should return default
+        v = cache.get_immediate("key1", 1)
+        self.assertEqual(v, 1)
+
+        # now complete the set
+        d1.callback(2)
+
+        # get_immediate should return result
+        v = cache.get_immediate("key1", 1)
+        self.assertEqual(v, 2)
+
     def test_invalidate(self):
         cache = DeferredCache("test")
         cache.prefill(("foo",), 123)
@@ -80,9 +96,11 @@ class DeferredCacheTestCase(unittest.TestCase):
         # now do the invalidation
         cache.invalidate_all()
 
-        # lookup should return none
-        self.assertIsNone(cache.get("key1", None))
-        self.assertIsNone(cache.get("key2", None))
+        # lookup should fail
+        with self.assertRaises(KeyError):
+            cache.get("key1")
+        with self.assertRaises(KeyError):
+            cache.get("key2")
 
         # both callbacks should have been callbacked
         self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
         self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
 
         # letting the other lookup complete should do nothing
         d1.callback("result1")
-        self.assertIsNone(cache.get("key1", None))
+        with self.assertRaises(KeyError):
+            cache.get("key1", None)
 
     def test_eviction(self):
         cache = DeferredCache(
-- cgit 1.5.1


From 96e7d3c4a0feec6d19b873fd550bcfffd485d910 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Mon, 19 Oct 2020 21:13:50 +0100
Subject: Fix 'LruCache' object has no attribute '_on_resize' (#8591)

We need to make sure we are ready for the `set_cache_factor` callback.
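To spell out the failure mode this fixes: when a per-cache factor is configured, register_cache() applies it synchronously during LruCache construction, so set_cache_factor() can run before the local eviction closure has been assigned to self._on_resize. A sketch of the crash, using the config from the regression test below (illustrative; a real run needs the homeserver config machinery):

    from synapse.util.caches.lrucache import LruCache

    # With {"caches": {"per_cache_factors": {"mycache": 10}}} configured,
    # registration resizes the cache while __init__ is still running.
    # Previously this raised:
    #     AttributeError: 'LruCache' object has no attribute '_on_resize'
    # With the fix, max_size is updated and the eviction pass is simply
    # skipped until the hook has been installed.
    cache = LruCache(10, "mycache")
    assert cache.max_size == 100  # the factor of 10 was applied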
--- changelog.d/8591.misc | 1 + synapse/util/caches/lrucache.py | 10 +++++++++- tests/util/test_lrucache.py | 8 +++++++- 3 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8591.misc (limited to 'tests/util') diff --git a/changelog.d/8591.misc b/changelog.d/8591.misc new file mode 100644 index 0000000000..8f16bc3e7e --- /dev/null +++ b/changelog.d/8591.misc @@ -0,0 +1 @@ + Move metric registration code down into `LruCache`. diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 3b471d8fd3..60bb6ff642 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -124,6 +124,10 @@ class LruCache(Generic[KT, VT]): else: self.max_size = int(max_size) + # register_cache might call our "set_cache_factor" callback; there's nothing to + # do yet when we get resized. + self._on_resize = None # type: Optional[Callable[[],None]] + if cache_name is not None: metrics = register_cache( "lru_cache", @@ -332,7 +336,10 @@ class LruCache(Generic[KT, VT]): return key in cache self.sentinel = object() + + # make sure that we clear out any excess entries after we get resized. self._on_resize = evict + self.get = cache_get self.set = cache_set self.setdefault = cache_set_default @@ -383,6 +390,7 @@ class LruCache(Generic[KT, VT]): new_size = int(self._original_max_size * factor) if new_size != self.max_size: self.max_size = new_size - self._on_resize() + if self._on_resize: + self._on_resize() return True return False diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index f12834edab..a739a6aaaf 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -19,7 +19,8 @@ from mock import Mock from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache -from .. import unittest +from tests import unittest +from tests.unittest import override_config class LruCacheTestCase(unittest.HomeserverTestCase): @@ -83,6 +84,11 @@ class LruCacheTestCase(unittest.HomeserverTestCase): cache.clear() self.assertEquals(len(cache), 0) + @override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) + def test_special_size(self): + cache = LruCache(10, "mycache") + self.assertEqual(cache.max_size, 100) + class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self): -- cgit 1.5.1 From 7b71695388fa2edd7ea5fd946b3d2afb68f4ef9d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 16 Oct 2020 22:31:16 +0100 Subject: Combine the two sets of tests for CacheDescriptor --- tests/storage/test__base.py | 228 --------------------------------- tests/util/caches/test_descriptors.py | 230 ++++++++++++++++++++++++++++++++++ 2 files changed, 230 insertions(+), 228 deletions(-) (limited to 'tests/util') diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8e69b1e9cc..1ac4ebc61d 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -15,237 +15,9 @@ # limitations under the License. 
-from mock import Mock - -from twisted.internet import defer - -from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import cached - from tests import unittest -class CacheDecoratorTestCase(unittest.HomeserverTestCase): - @defer.inlineCallbacks - def test_passthrough(self): - class A: - @cached() - def func(self, key): - return key - - a = A() - - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals((yield a.func("bar")), "bar") - - @defer.inlineCallbacks - def test_hit(self): - callcount = [0] - - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key - - a = A() - yield a.func("foo") - - self.assertEquals(callcount[0], 1) - - self.assertEquals((yield a.func("foo")), "foo") - self.assertEquals(callcount[0], 1) - - @defer.inlineCallbacks - def test_invalidate(self): - callcount = [0] - - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key - - a = A() - yield a.func("foo") - - self.assertEquals(callcount[0], 1) - - a.func.invalidate(("foo",)) - - yield a.func("foo") - - self.assertEquals(callcount[0], 2) - - def test_invalidate_missing(self): - class A: - @cached() - def func(self, key): - return key - - A().func.invalidate(("what",)) - - @defer.inlineCallbacks - def test_max_entries(self): - callcount = [0] - - class A: - @cached(max_entries=10) - def func(self, key): - callcount[0] += 1 - return key - - a = A() - - for k in range(0, 12): - yield a.func(k) - - self.assertEquals(callcount[0], 12) - - # There must have been at least 2 evictions, meaning if we calculate - # all 12 values again, we must get called at least 2 more times - for k in range(0, 12): - yield a.func(k) - - self.assertTrue( - callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0]) - ) - - def test_prefill(self): - callcount = [0] - - d = defer.succeed(123) - - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return d - - a = A() - - a.func.prefill(("foo",), ObservableDeferred(d)) - - self.assertEquals(a.func("foo").result, d.result) - self.assertEquals(callcount[0], 0) - - @defer.inlineCallbacks - def test_invalidate_context(self): - callcount = [0] - callcount2 = [0] - - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key - - @cached(cache_context=True) - def func2(self, key, cache_context): - callcount2[0] += 1 - return self.func(key, on_invalidate=cache_context.invalidate) - - a = A() - yield a.func2("foo") - - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) - - a.func.invalidate(("foo",)) - yield a.func("foo") - - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 1) - - yield a.func2("foo") - - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) - - @defer.inlineCallbacks - def test_eviction_context(self): - callcount = [0] - callcount2 = [0] - - class A: - @cached(max_entries=2) - def func(self, key): - callcount[0] += 1 - return key - - @cached(cache_context=True) - def func2(self, key, cache_context): - callcount2[0] += 1 - return self.func(key, on_invalidate=cache_context.invalidate) - - a = A() - yield a.func2("foo") - yield a.func2("foo2") - - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) - - yield a.func2("foo") - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) - - yield a.func("foo3") - - self.assertEquals(callcount[0], 3) - self.assertEquals(callcount2[0], 2) - - yield a.func2("foo") - - self.assertEquals(callcount[0], 4) - 
self.assertEquals(callcount2[0], 3) - - @defer.inlineCallbacks - def test_double_get(self): - callcount = [0] - callcount2 = [0] - - class A: - @cached() - def func(self, key): - callcount[0] += 1 - return key - - @cached(cache_context=True) - def func2(self, key, cache_context): - callcount2[0] += 1 - return self.func(key, on_invalidate=cache_context.invalidate) - - a = A() - a.func2.cache.cache = Mock(wraps=a.func2.cache.cache) - - yield a.func2("foo") - - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 1) - - a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 1) - - yield a.func2("foo") - a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 2) - - self.assertEquals(callcount[0], 1) - self.assertEquals(callcount2[0], 2) - - a.func.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 3) - yield a.func("foo") - - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 2) - - yield a.func2("foo") - - self.assertEquals(callcount[0], 2) - self.assertEquals(callcount2[0], 3) - - class UpsertManyTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.storage = hs.get_datastore() diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 3d1f960869..3d738afa7f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -27,6 +27,7 @@ from synapse.logging.context import ( current_context, make_deferred_yieldable, ) +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import descriptors from synapse.util.caches.descriptors import cached @@ -311,6 +312,235 @@ class DescriptorTestCase(unittest.TestCase): self.failureResultOf(d, SynapseError) +class CacheDecoratorTestCase(unittest.HomeserverTestCase): + """More tests for @cached + + The following is a set of tests that got lost in a different file for a while. + + There are probably duplicates of the tests in DescriptorTestCase. Ideally the + duplicates would be removed and the two sets of classes combined. 
+ """ + + @defer.inlineCallbacks + def test_passthrough(self): + class A: + @cached() + def func(self, key): + return key + + a = A() + + self.assertEquals((yield a.func("foo")), "foo") + self.assertEquals((yield a.func("bar")), "bar") + + @defer.inlineCallbacks + def test_hit(self): + callcount = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + a = A() + yield a.func("foo") + + self.assertEquals(callcount[0], 1) + + self.assertEquals((yield a.func("foo")), "foo") + self.assertEquals(callcount[0], 1) + + @defer.inlineCallbacks + def test_invalidate(self): + callcount = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + a = A() + yield a.func("foo") + + self.assertEquals(callcount[0], 1) + + a.func.invalidate(("foo",)) + + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + + def test_invalidate_missing(self): + class A: + @cached() + def func(self, key): + return key + + A().func.invalidate(("what",)) + + @defer.inlineCallbacks + def test_max_entries(self): + callcount = [0] + + class A: + @cached(max_entries=10) + def func(self, key): + callcount[0] += 1 + return key + + a = A() + + for k in range(0, 12): + yield a.func(k) + + self.assertEquals(callcount[0], 12) + + # There must have been at least 2 evictions, meaning if we calculate + # all 12 values again, we must get called at least 2 more times + for k in range(0, 12): + yield a.func(k) + + self.assertTrue( + callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0]) + ) + + def test_prefill(self): + callcount = [0] + + d = defer.succeed(123) + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return d + + a = A() + + a.func.prefill(("foo",), ObservableDeferred(d)) + + self.assertEquals(a.func("foo").result, d.result) + self.assertEquals(callcount[0], 0) + + @defer.inlineCallbacks + def test_invalidate_context(self): + callcount = [0] + callcount2 = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + yield a.func2("foo") + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 1) + + a.func.invalidate(("foo",)) + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 1) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + @defer.inlineCallbacks + def test_eviction_context(self): + callcount = [0] + callcount2 = [0] + + class A: + @cached(max_entries=2) + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + yield a.func2("foo") + yield a.func2("foo2") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func("foo3") + + self.assertEquals(callcount[0], 3) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 4) + self.assertEquals(callcount2[0], 3) + + @defer.inlineCallbacks + def test_double_get(self): + callcount = [0] + callcount2 = [0] + + class A: + @cached() + def func(self, key): + callcount[0] += 1 + return key + + @cached(cache_context=True) + def 
func2(self, key, cache_context): + callcount2[0] += 1 + return self.func(key, on_invalidate=cache_context.invalidate) + + a = A() + a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 1) + + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 1) + + yield a.func2("foo") + a.func2.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 2) + + self.assertEquals(callcount[0], 1) + self.assertEquals(callcount2[0], 2) + + a.func.invalidate(("foo",)) + self.assertEquals(a.func2.cache.cache.pop.call_count, 3) + yield a.func("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 2) + + yield a.func2("foo") + + self.assertEquals(callcount[0], 2) + self.assertEquals(callcount2[0], 3) + + class CachedListDescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): -- cgit 1.5.1 From 1f4269700c2353263a605856f28ded28501368e1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 16 Oct 2020 12:34:55 +0100 Subject: Push some deferred wrangling down into DeferredCache --- changelog.d/8572.misc | 1 + synapse/util/caches/deferred_cache.py | 57 +++++++++++++++++++++++++++----- synapse/util/caches/descriptors.py | 32 ++++-------------- tests/util/caches/test_deferred_cache.py | 18 +++++----- tests/util/caches/test_descriptors.py | 5 ++- 5 files changed, 67 insertions(+), 46 deletions(-) create mode 100644 changelog.d/8572.misc (limited to 'tests/util') diff --git a/changelog.d/8572.misc b/changelog.d/8572.misc new file mode 100644 index 0000000000..ea2a6d340d --- /dev/null +++ b/changelog.d/8572.misc @@ -0,0 +1 @@ +Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s. diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index faeef75506..6c162e9f34 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -57,7 +57,7 @@ class DeferredCache(Generic[KT, VT]): """Wraps an LruCache, adding support for Deferred results. It expects that each entry added with set() will be a Deferred; likewise get() - may return an ObservableDeferred. + will return a Deferred. """ __slots__ = ( @@ -130,16 +130,22 @@ class DeferredCache(Generic[KT, VT]): key: KT, callback: Optional[Callable[[], None]] = None, update_metrics: bool = True, - ) -> Union[ObservableDeferred, VT]: + ) -> defer.Deferred: """Looks the key up in the caches. + For symmetry with set(), this method does *not* follow the synapse logcontext + rules: the logcontext will not be cleared on return, and the Deferred will run + its callbacks in the sentinel context. In other words: wrap the result with + make_deferred_yieldable() before `await`ing it. + Args: - key(tuple) - callback(fn): Gets called when the entry in the cache is invalidated + key: + callback: Gets called when the entry in the cache is invalidated update_metrics (bool): whether to update the cache hit rate metrics Returns: - Either an ObservableDeferred or the result itself + A Deferred which completes with the result. Note that this may later fail + if there is an ongoing set() operation which later completes with a failure. 
        Raises:
            KeyError if the key is not found in the cache
@@ -152,7 +158,7 @@ class DeferredCache(Generic[KT, VT]):
             m = self.cache.metrics
             assert m  # we always have a name, so should always have metrics
             m.inc_hits()
-            return val.deferred
+            return val.deferred.observe()
 
         val2 = self.cache.get(
             key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -160,7 +166,7 @@ class DeferredCache(Generic[KT, VT]):
         if val2 is _Sentinel.sentinel:
             raise KeyError()
         else:
-            return val2
+            return defer.succeed(val2)
 
     def get_immediate(
         self, key: KT, default: T, update_metrics: bool = True
@@ -173,7 +179,36 @@ class DeferredCache(Generic[KT, VT]):
         key: KT,
         value: defer.Deferred,
         callback: Optional[Callable[[], None]] = None,
-    ) -> ObservableDeferred:
+    ) -> defer.Deferred:
+        """Adds a new entry to the cache (or updates an existing one).
+
+        The given `value` *must* be a Deferred.
+
+        First any existing entry for the same key is invalidated. Then a new entry
+        is added to the cache for the given key.
+
+        Until the `value` completes, calls to `get()` for the key will also result in an
+        incomplete Deferred, which will ultimately complete with the same result as
+        `value`.
+
+        If `value` completes successfully, subsequent calls to `get()` will then return
+        a completed deferred with the same result. If it *fails*, the cache is
+        invalidated and subsequent calls to `get()` will raise a KeyError.
+
+        If another call to `set()` happens before `value` completes, then (a) any
+        invalidation callbacks registered in the interim will be called, (b) any
+        `get()`s in the interim will continue to complete with the result from the
+        *original* `value`, (c) any future calls to `get()` will complete with the
+        result from the *new* `value`.
+
+        It is expected that `value` does *not* follow the synapse logcontext rules - ie,
+        if it is incomplete, it runs its callbacks in the sentinel context.
+
+        Args:
+            key: Key to be set
+            value: a deferred which will complete with a result to add to the cache
+            callback: An optional callback to be called when the entry is invalidated
+        """
         if not isinstance(value, defer.Deferred):
             raise TypeError("not a Deferred")
 
@@ -187,6 +222,8 @@ class DeferredCache(Generic[KT, VT]):
         if existing_entry:
             existing_entry.invalidate()
 
+        # XXX: why don't we invalidate the entry in `self.cache` yet?
+
         self._pending_deferred_cache[key] = entry
 
         def compare_and_pop():
@@ -230,7 +267,9 @@ class DeferredCache(Generic[KT, VT]):
         #    _pending_deferred_cache to the real cache.
         #
         observer.addCallbacks(cb, eb)
-        return observable
+
+        # we return a new Deferred which will be called before any subsequent observers.
+ return observable.observe() def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): callbacks = [callback] if callback else [] diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1f43886804..a4172345ef 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -23,7 +23,6 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) @@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase): keylen=self.num_args, tree=self.tree, iterable=self.iterable, - ) # type: DeferredCache[Tuple, Any] + ) # type: DeferredCache[CacheKey, Any] def get_cache_key_gen(args, kwargs): """Given some args/kwargs return a generator that resolves into @@ -208,26 +207,12 @@ class CacheDescriptor(_CacheDescriptorBase): kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) try: - cached_result_d = cache.get(cache_key, callback=invalidate_callback) - - if isinstance(cached_result_d, ObservableDeferred): - observer = cached_result_d.observe() - else: - observer = defer.succeed(cached_result_d) - + ret = cache.get(cache_key, callback=invalidate_callback) except KeyError: ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) + ret = cache.set(cache_key, ret, callback=invalidate_callback) - def onErr(f): - cache.invalidate(cache_key) - return f - - ret.addErrback(onErr) - - result_d = cache.set(cache_key, ret, callback=invalidate_callback) - observer = result_d.observe() - - return make_deferred_yieldable(observer) + return make_deferred_yieldable(ret) wrapped = cast(_CachedFunction, _wrapped) @@ -286,7 +271,7 @@ class CacheListDescriptor(_CacheDescriptorBase): def __get__(self, obj, objtype=None): cached_method = getattr(obj, self.cached_method_name) - cache = cached_method.cache + cache = cached_method.cache # type: DeferredCache[CacheKey, Any] num_args = cached_method.num_args @functools.wraps(self.orig) @@ -326,14 +311,11 @@ class CacheListDescriptor(_CacheDescriptorBase): for arg in list_args: try: res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) - if not isinstance(res, ObservableDeferred): - results[arg] = res - elif not res.has_succeeded(): - res = res.observe() + if not res.called: res.addCallback(update_results_dict, arg) cached_defers.append(res) else: - results[arg] = res.get_result() + results[arg] = res.result except KeyError: missing.add(arg) diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 8a08ab6661..68d26128c1 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest
 from functools import partial
 
 from twisted.internet import defer
 
 from synapse.util.caches.deferred_cache import DeferredCache
 
+from tests.unittest import TestCase
 
-class DeferredCacheTestCase(unittest.TestCase):
+
+class DeferredCacheTestCase(TestCase):
     def test_empty(self):
         cache = DeferredCache("test")
         failed = False
@@ -36,7 +37,7 @@ class DeferredCacheTestCase(unittest.TestCase):
         cache = DeferredCache("test")
         cache.prefill("foo", 123)
 
-        self.assertEquals(cache.get("foo"), 123)
+        self.assertEquals(self.successResultOf(cache.get("foo")), 123)
 
     def test_get_immediate(self):
         cache = DeferredCache("test")
         d1 = defer.Deferred()
@@ -82,16 +83,15 @@ class DeferredCacheTestCase(unittest.TestCase):
         d2 = defer.Deferred()
         cache.set("key2", d2, partial(record_callback, 1))
 
-        # lookup should return observable deferreds
-        self.assertFalse(cache.get("key1").has_called())
-        self.assertFalse(cache.get("key2").has_called())
+        # lookup should return pending deferreds
+        self.assertFalse(cache.get("key1").called)
+        self.assertFalse(cache.get("key2").called)
 
         # let one of the lookups complete
         d2.callback("result2")
 
-        # for now at least, the cache will return real results rather than an
-        # observabledeferred
-        self.assertEqual(cache.get("key2"), "result2")
+        # now the cache will return a completed deferred
+        self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
 
         # now do the invalidation
         cache.invalidate_all()
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 3d738afa7f..fc2663c02d 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -27,7 +27,6 @@ from synapse.logging.context import (
     current_context,
     make_deferred_yieldable,
 )
-from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches import descriptors
 from synapse.util.caches.descriptors import cached
 
@@ -419,9 +418,9 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
 
         a = A()
 
-        a.func.prefill(("foo",), ObservableDeferred(d))
+        a.func.prefill(("foo",), 456)
 
-        self.assertEquals(a.func("foo").result, d.result)
+        self.assertEquals(a.func("foo").result, 456)
         self.assertEquals(callcount[0], 0)
 
     @defer.inlineCallbacks
-- 
cgit 1.5.1


From 6d3905c7c7a53eed7a856aa013f6a9bf9292eb7a Mon Sep 17 00:00:00 2001
From: Richard van der Hoff
Date: Fri, 16 Oct 2020 21:32:52 +0100
Subject: Add some more tests

---
 tests/util/caches/test_deferred_cache.py | 95 ++++++++++++++++++++++++++++++++
 tests/util/caches/test_descriptors.py    | 52 +++++++++++++++++
 2 files changed, 147 insertions(+)

(limited to 'tests/util')

diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index 68d26128c1..dadfabd46d 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -39,6 +39,101 @@ class DeferredCacheTestCase(TestCase):
 
         self.assertEquals(self.successResultOf(cache.get("foo")), 123)
 
+    def test_hit_deferred(self):
+        cache = DeferredCache("test")
+        origin_d = defer.Deferred()
+        set_d = cache.set("k1", origin_d)
+
+        # get should return an incomplete deferred
+        get_d = cache.get("k1")
+        self.assertFalse(get_d.called)
+
+        # add a callback that will make sure that the set_d gets called before the get_d
+        def check1(r):
+            self.assertTrue(set_d.called)
+            return r
+
+        # TODO: Actually ObservableDeferred *doesn't* run its callbacks in order on py3.8.
+        #   maybe we should fix that?
+ # get_d.addCallback(check1) + + # now fire off all the deferreds + origin_d.callback(99) + self.assertEqual(self.successResultOf(origin_d), 99) + self.assertEqual(self.successResultOf(set_d), 99) + self.assertEqual(self.successResultOf(get_d), 99) + + def test_callbacks(self): + """Invalidation callbacks are called at the right time""" + cache = DeferredCache("test") + callbacks = set() + + # start with an entry, with a callback + cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) + + # now replace that entry with a pending result + origin_d = defer.Deferred() + set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) + + # ... and also make a get request + get_d = cache.get("k1", callback=lambda: callbacks.add("get")) + + # we don't expect the invalidation callback for the original value to have + # been called yet, even though get() will now return a different result. + # I'm not sure if that is by design or not. + self.assertEqual(callbacks, set()) + + # now fire off all the deferreds + origin_d.callback(20) + self.assertEqual(self.successResultOf(set_d), 20) + self.assertEqual(self.successResultOf(get_d), 20) + + # now the original invalidation callback should have been called, but none of + # the others + self.assertEqual(callbacks, {"prefill"}) + callbacks.clear() + + # another update should invalidate both the previous results + cache.prefill("k1", 30) + self.assertEqual(callbacks, {"set", "get"}) + + def test_set_fail(self): + cache = DeferredCache("test") + callbacks = set() + + # start with an entry, with a callback + cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) + + # now replace that entry with a pending result + origin_d = defer.Deferred() + set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) + + # ... and also make a get request + get_d = cache.get("k1", callback=lambda: callbacks.add("get")) + + # none of the callbacks should have been called yet + self.assertEqual(callbacks, set()) + + # oh noes! fails! + e = Exception("oops") + origin_d.errback(e) + self.assertIs(self.failureResultOf(set_d, Exception).value, e) + self.assertIs(self.failureResultOf(get_d, Exception).value, e) + + # the callbacks for the failed requests should have been called. + # I'm not sure if this is deliberate or not. + self.assertEqual(callbacks, {"get", "set"}) + callbacks.clear() + + # the old value should still be returned now? + get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2")) + self.assertEqual(self.successResultOf(get_d2), 10) + + # replacing the value now should run the callbacks for those requests + # which got the original result + cache.prefill("k1", 30) + self.assertEqual(callbacks, {"prefill", "get2"}) + def test_get_immediate(self): cache = DeferredCache("test") d1 = defer.Deferred() diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index fc2663c02d..2ad08f541b 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +from typing import Set import mock @@ -130,6 +131,57 @@ class DescriptorTestCase(unittest.TestCase): d = obj.fn(1) self.failureResultOf(d, SynapseError) + def test_cache_with_async_exception(self): + """The wrapped function returns a failure + """ + + class Cls: + result = None + call_count = 0 + + @cached() + def fn(self, arg1): + self.call_count += 1 + return self.result + + obj = Cls() + callbacks = set() # type: Set[str] + + # set off an asynchronous request + obj.result = origin_d = defer.Deferred() + + d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) + self.assertFalse(d1.called) + + # a second request should also return a deferred, but should not call the + # function itself. + d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2")) + self.assertFalse(d2.called) + self.assertEqual(obj.call_count, 1) + + # no callbacks yet + self.assertEqual(callbacks, set()) + + # the original request fails + e = Exception("bzz") + origin_d.errback(e) + + # ... which should cause the lookups to fail similarly + self.assertIs(self.failureResultOf(d1, Exception).value, e) + self.assertIs(self.failureResultOf(d2, Exception).value, e) + + # ... and the callbacks to have been, uh, called. + self.assertEqual(callbacks, {"d1", "d2"}) + + # ... leaving the cache empty + self.assertEqual(len(obj.fn.cache.cache), 0) + + # and a second call should work as normal + obj.result = defer.succeed(100) + d3 = obj.fn(1) + self.assertEqual(self.successResultOf(d3), 100) + self.assertEqual(obj.call_count, 2) + def test_cache_logcontexts(self): """Check that logcontexts are set and restored correctly when using the cache.""" -- cgit 1.5.1 From cbc82aa09faa59acc20865e8b5c36561acb9a570 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 30 Oct 2020 11:43:17 +0000 Subject: Implement and use an @lru_cache decorator (#8595) We don't always need the full power of a DeferredCache. --- changelog.d/8595.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 37 +++-- synapse/util/caches/descriptors.py | 235 ++++++++++++++++++++++++------- tests/util/caches/test_descriptors.py | 60 +++++++- 4 files changed, 272 insertions(+), 61 deletions(-) create mode 100644 changelog.d/8595.misc (limited to 'tests/util') diff --git a/changelog.d/8595.misc b/changelog.d/8595.misc new file mode 100644 index 0000000000..24fab65cda --- /dev/null +++ b/changelog.d/8595.misc @@ -0,0 +1 @@ +Implement and use an @lru_cache decorator. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d9b5478b53..82a72dc34f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -15,8 +15,8 @@ # limitations under the License. 
 import logging
-from collections import namedtuple
 
+import attr
 from prometheus_client import Counter
 
 from synapse.api.constants import EventTypes, Membership, RelationTypes
@@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import register_cache
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import lru_cache
+from synapse.util.caches.lrucache import LruCache
 
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
@@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
             dict of user_id -> push_rules
         """
         room_id = event.room_id
-        rules_for_room = await self._get_rules_for_room(room_id)
+        rules_for_room = self._get_rules_for_room(room_id)
 
         rules_by_user = await rules_for_room.get_rules(event, context)
 
@@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
 
         return rules_by_user
 
-    @cached()
+    @lru_cache()
     def _get_rules_for_room(self, room_id):
         """Get the current RulesForRoom object for the given room id
 
@@ -275,12 +276,14 @@ class RulesForRoom:
         the entire cache for the room.
     """
 
-    def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
+    def __init__(
+        self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+    ):
         """
         Args:
             hs (HomeServer)
             room_id (str)
-            rules_for_room_cache(Cache): The cache object that caches these
+            rules_for_room_cache: The cache object that caches these
                 RoomsForUser objects.
             room_push_rule_cache_metrics (CacheMetric)
         """
@@ -489,13 +492,21 @@ class RulesForRoom:
             self.state_group = state_group
 
 
-class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
-    # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
-    # which namedtuple does for us (i.e. two _CacheContext are the same if
-    # their caches and keys match). This is important in particular to
-    # dedupe when we add callbacks to lru cache nodes, otherwise the number
-    # of callbacks would grow.
+@attr.attrs(slots=True, frozen=True)
+class _Invalidation:
+    # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
+    # which means that it is stored on the bulk_get_push_rules cache entry. In order
+    # to ensure that we don't accumulate lots of redundant callbacks on the cache entry,
+    # we need to ensure that two _Invalidation objects are "equal" if they refer to the
+    # same `cache` and `room_id`.
+    #
+    # attrs provides suitable __hash__ and __eq__ methods, provided we remember to
+    # set `frozen=True`.
+
+    cache = attr.ib(type=LruCache)
+    room_id = attr.ib(type=str)
+
     def __call__(self):
-        rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
+        rules = self.cache.get(self.room_id, None, update_metrics=False)
         if rules:
             rules.invalidate_all()
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 5d7fffee66..a924140cdf 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,10 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import functools
 import inspect
 import logging
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 from weakref import WeakValueDictionary
 
 from twisted.internet import defer
@@ -24,6 +37,7 @@ from twisted.internet import defer
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
 from synapse.util.caches.deferred_cache import DeferredCache
+from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]):
 
 
 class _CacheDescriptorBase:
-    def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
+    def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
         self.orig = orig
 
         arg_spec = inspect.getfullargspec(orig)
@@ -97,8 +111,107 @@ class _CacheDescriptorBase:
 
         self.add_cache_context = cache_context
 
+        self.cache_key_builder = get_cache_key_builder(
+            self.arg_names, self.arg_defaults
+        )
+
+
+class _LruCachedFunction(Generic[F]):
+    cache = None  # type: LruCache[CacheKey, Any]
+    __call__ = None  # type: F
+
+
+def lru_cache(
+    max_entries: int = 1000, cache_context: bool = False,
+) -> Callable[[F], _LruCachedFunction[F]]:
+    """A method decorator that applies a memoizing cache around the function.
+
+    This is more-or-less a drop-in equivalent to functools.lru_cache, although note
+    that the signature is slightly different.
+
+    The main differences with functools.lru_cache are:
+        (a) the size of the cache can be controlled via the cache_factor mechanism
+        (b) the wrapped function can request a "cache_context" which provides a
+            callback mechanism to indicate that the result is no longer valid
+        (c) prometheus metrics are exposed automatically.
+
+    The function should take zero or more arguments, which are used as the key for the
+    cache. Single-argument functions use that argument as the cache key; otherwise the
+    arguments are built into a tuple.
+
+    Cached functions can be "chained" (i.e. a cached function can call other cached
+    functions and get appropriately invalidated when the caches they call are
+    invalidated) by adding a special "cache_context" argument to the function
+    and passing that as a kwarg to all caches called. For example:
+
+        @lru_cache(cache_context=True)
+        def foo(self, key, cache_context):
+            r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
+            r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
+            return r1 + r2
+
+    The wrapped function also has a 'cache' property which offers direct access to the
+    underlying LruCache.
+ """ + + def func(orig: F) -> _LruCachedFunction[F]: + desc = LruCacheDescriptor( + orig, max_entries=max_entries, cache_context=cache_context, + ) + return cast(_LruCachedFunction[F], desc) + + return func + + +class LruCacheDescriptor(_CacheDescriptorBase): + """Helper for @lru_cache""" + + class _Sentinel(enum.Enum): + sentinel = object() + + def __init__( + self, orig, max_entries: int = 1000, cache_context: bool = False, + ): + super().__init__(orig, num_args=None, cache_context=cache_context) + self.max_entries = max_entries + + def __get__(self, obj, owner): + cache = LruCache( + cache_name=self.orig.__name__, max_size=self.max_entries, + ) # type: LruCache[CacheKey, Any] + + get_cache_key = self.cache_key_builder + sentinel = LruCacheDescriptor._Sentinel.sentinel + + @functools.wraps(self.orig) + def _wrapped(*args, **kwargs): + invalidate_callback = kwargs.pop("on_invalidate", None) + callbacks = (invalidate_callback,) if invalidate_callback else () + + cache_key = get_cache_key(args, kwargs) -class CacheDescriptor(_CacheDescriptorBase): + ret = cache.get(cache_key, default=sentinel, callbacks=callbacks) + if ret != sentinel: + return ret + + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) + + ret2 = self.orig(obj, *args, **kwargs) + cache.set(cache_key, ret2, callbacks=callbacks) + + return ret2 + + wrapped = cast(_CachedFunction, _wrapped) + wrapped.cache = cache + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +class DeferredCacheDescriptor(_CacheDescriptorBase): """ A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that @@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=False, iterable=False, ): - super().__init__(orig, num_args=num_args, cache_context=cache_context) self.max_entries = max_entries @@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase): iterable=self.iterable, ) # type: DeferredCache[CacheKey, Any] - def get_cache_key_gen(args, kwargs): - """Given some args/kwargs return a generator that resolves into - the cache_key. - - We loop through each arg name, looking up if its in the `kwargs`, - otherwise using the next argument in `args`. If there are no more - args then we try looking the arg name up in the defaults - """ - pos = 0 - for nm in self.arg_names: - if nm in kwargs: - yield kwargs[nm] - elif pos < len(args): - yield args[pos] - pos += 1 - else: - yield self.arg_defaults[nm] - - # By default our cache key is a tuple, but if there is only one item - # then don't bother wrapping in a tuple. This is to save memory. 
- if self.num_args == 1: - nm = self.arg_names[0] - - def get_cache_key(args, kwargs): - if nm in kwargs: - return kwargs[nm] - elif len(args): - return args[0] - else: - return self.arg_defaults[nm] - - else: - - def get_cache_key(args, kwargs): - return tuple(get_cache_key_gen(args, kwargs)) + get_cache_key = self.cache_key_builder @functools.wraps(self.orig) def _wrapped(*args, **kwargs): @@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase): wrapped.prefill = lambda key, val: cache.prefill(key[0], val) else: wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_many = cache.invalidate_many wrapped.prefill = cache.prefill @@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase): return wrapped -class CacheListDescriptor(_CacheDescriptorBase): +class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given a list of keys it looks in the cache to find any hits, then passes @@ -382,11 +459,13 @@ class _CacheContext: on a lower level. """ + Cache = Union[DeferredCache, LruCache] + _cache_context_objects = ( WeakValueDictionary() - ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext] + ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext] - def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None + def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None: self._cache = cache self._cache_key = cache_key @@ -396,8 +475,8 @@ class _CacheContext: @classmethod def get_instance( - cls, cache, cache_key - ): # type: (DeferredCache, CacheKey) -> _CacheContext + cls, cache: "_CacheContext.Cache", cache_key: CacheKey + ) -> "_CacheContext": """Returns an instance constructed with the given arguments. A new instance is only created if none already exists. @@ -418,7 +497,7 @@ def cached( cache_context: bool = False, iterable: bool = False, ) -> Callable[[F], _CachedFunction[F]]: - func = lambda orig: CacheDescriptor( + func = lambda orig: DeferredCacheDescriptor( orig, max_entries=max_entries, num_args=num_args, @@ -460,7 +539,7 @@ def cachedList( def batch_do_something(self, first_arg, second_args): ... """ - func = lambda orig: CacheListDescriptor( + func = lambda orig: DeferredCacheListDescriptor( orig, cached_method_name=cached_method_name, list_name=list_name, @@ -468,3 +547,65 @@ def cachedList( ) return cast(Callable[[F], _CachedFunction[F]], func) + + +def get_cache_key_builder( + param_names: Sequence[str], param_defaults: Mapping[str, Any] +) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: + """Construct a function which will build cache keys suitable for a cached function + + Args: + param_names: list of formal parameter names for the cached function + param_defaults: a mapping from parameter name to default value for that param + + Returns: + A function which will take an (args, kwargs) pair and return a cache key + """ + + # By default our cache key is a tuple, but if there is only one item + # then don't bother wrapping in a tuple. This is to save memory. 
+
+    if len(param_names) == 1:
+        nm = param_names[0]
+
+        def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+            if nm in kwargs:
+                return kwargs[nm]
+            elif len(args):
+                return args[0]
+            else:
+                return param_defaults[nm]
+
+    else:
+
+        def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
+            return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
+
+    return get_cache_key
+
+
+def _get_cache_key_gen(
+    param_names: Iterable[str],
+    param_defaults: Mapping[str, Any],
+    args: Sequence[Any],
+    kwargs: Mapping[str, Any],
+) -> Iterable[Any]:
+    """Given some args/kwargs return a generator that resolves into
+    the cache_key.
+
+    This is essentially the same operation as `inspect.getcallargs`, but optimised so
+    that we don't need to inspect the target function for each call.
+    """
+
+    # We loop through each arg name, looking up if it's in the `kwargs`,
+    # otherwise using the next argument in `args`. If there are no more
+    # args then we try looking the arg name up in the defaults.
+    pos = 0
+    for nm in param_names:
+        if nm in kwargs:
+            yield kwargs[nm]
+        elif pos < len(args):
+            yield args[pos]
+            pos += 1
+        else:
+            yield param_defaults[nm]
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2ad08f541b..cf1e3203a4 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -29,13 +29,46 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, lru_cache
 
 from tests import unittest
+from tests.test_utils import get_awaitable_result
 
 logger = logging.getLogger(__name__)
 
 
+class LruCacheDecoratorTestCase(unittest.TestCase):
+    def test_base(self):
+        class Cls:
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @lru_cache()
+            def fn(self, arg1, arg2):
+                return self.mock(arg1, arg2)
+
+        obj = Cls()
+        obj.mock.return_value = "fish"
+        r = obj.fn(1, 2)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_called_once_with(1, 2)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = "chips"
+        r = obj.fn(1, 3)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_called_once_with(1, 3)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached
+        r = obj.fn(1, 2)
+        self.assertEqual(r, "fish")
+        r = obj.fn(1, 3)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_not_called()
+
+
 def run_on_reactor():
     d = defer.Deferred()
     reactor.callLater(0, d.callback, 0)
@@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase):
         d = obj.fn(1)
         self.failureResultOf(d, SynapseError)
 
+    def test_invalidate_cascade(self):
+        """Invalidations should cascade up through cache contexts"""
+
+        class Cls:
+            @cached(cache_context=True)
+            async def func1(self, key, cache_context):
+                return await self.func2(key, on_invalidate=cache_context.invalidate)
+
+            @cached(cache_context=True)
+            async def func2(self, key, cache_context):
+                return self.func3(key, on_invalidate=cache_context.invalidate)
+
+            @lru_cache(cache_context=True)
+            def func3(self, key, cache_context):
+                self.invalidate = cache_context.invalidate
+                return 42
+
+        obj = Cls()
+
+        top_invalidate = mock.Mock()
+        r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
+        self.assertEqual(r, 42)
+        obj.invalidate()
+        top_invalidate.assert_called_once()
+
 
 class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached -- cgit 1.5.1