From 1781bbe319ce24e8e468f0422519dc5823d8d420 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 9 Oct 2020 11:35:11 -0400 Subject: Add type hints to response cache. (#8507) --- synapse/util/caches/response_cache.py | 50 ++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 22 deletions(-) (limited to 'synapse/util/caches') diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index df1a721add..32228f42ee 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar from twisted.internet import defer @@ -20,10 +21,15 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) +T = TypeVar("T") + -class ResponseCache: +class ResponseCache(Generic[T]): """ This caches a deferred response. Until the deferred completes it will be returned from the cache. This means that if the client retries the request @@ -31,8 +37,9 @@ class ResponseCache: used rather than trying to compute a new response. """ - def __init__(self, hs, name, timeout_ms=0): - self.pending_result_cache = {} # Requests that haven't finished yet. + def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): + # Requests that haven't finished yet. + self.pending_result_cache = {} # type: Dict[T, ObservableDeferred] self.clock = hs.get_clock() self.timeout_sec = timeout_ms / 1000.0 @@ -40,13 +47,13 @@ class ResponseCache: self._name = name self._metrics = register_cache("response_cache", name, self, resizable=False) - def size(self): + def size(self) -> int: return len(self.pending_result_cache) - def __len__(self): + def __len__(self) -> int: return self.size() - def get(self, key): + def get(self, key: T) -> Optional[defer.Deferred]: """Look up the given key. Can return either a new Deferred (which also doesn't follow the synapse @@ -58,12 +65,11 @@ class ResponseCache: from an absent cache entry. Args: - key (hashable): + key: key to get/set in the cache Returns: - twisted.internet.defer.Deferred|None|E: None if there is no entry - for this key; otherwise either a deferred result or the result - itself. + None if there is no entry for this key; otherwise a deferred which + resolves to the result. """ result = self.pending_result_cache.get(key) if result is not None: @@ -73,7 +79,7 @@ class ResponseCache: self._metrics.inc_misses() return None - def set(self, key, deferred): + def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred: """Set the entry for the given key to the given deferred. *deferred* should run its callbacks in the sentinel logcontext (ie, @@ -85,12 +91,11 @@ class ResponseCache: result. You will probably want to make_deferred_yieldable the result. Args: - key (hashable): - deferred (twisted.internet.defer.Deferred[T): + key: key to get/set in the cache + deferred: The deferred which resolves to the result. Returns: - twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual - result. + A new deferred which resolves to the actual result. 
""" result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result @@ -107,7 +112,9 @@ class ResponseCache: result.addBoth(remove) return result.observe() - def wrap(self, key, callback, *args, **kwargs): + def wrap( + self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any + ) -> defer.Deferred: """Wrap together a *get* and *set* call, taking care of logcontexts First looks up the key in the cache, and if it is present makes it @@ -118,21 +125,20 @@ class ResponseCache: Example usage: - @defer.inlineCallbacks - def handle_request(request): + async def handle_request(request): # etc return result - result = yield response_cache.wrap( + result = await response_cache.wrap( key, handle_request, request, ) Args: - key (hashable): key to get/set in the cache + key: key to get/set in the cache - callback (callable): function to call if the key is not found in + callback: function to call if the key is not found in the cache *args: positional parameters to pass to the callback, if it is used @@ -140,7 +146,7 @@ class ResponseCache: **kwargs: named parameters to pass to the callback, if it is used Returns: - twisted.internet.defer.Deferred: yieldable result + Deferred which resolves to the result """ result = self.get(key) if not result: -- cgit 1.5.1 From 7eff59ec91e59140c375b43a6dac05b833ab0051 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 19:40:53 +0100 Subject: Add some more type annotations to Cache --- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/util/caches/descriptors.py | 81 ++++++++++++++++++------- synapse/util/caches/lrucache.py | 3 +- 3 files changed, 62 insertions(+), 24 deletions(-) (limited to 'synapse/util/caches') diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 1f8dafe7ea..273d627fad 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -26,7 +26,7 @@ class SlavedClientIpStore(BaseSlavedStore): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 - ) + ) # type: Cache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 98b34f2223..14458bc20f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,12 +13,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import functools import inspect import logging import threading -from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Generic, + Iterable, + MutableMapping, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from weakref import WeakValueDictionary from prometheus_client import Gauge @@ -38,6 +49,8 @@ logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] F = TypeVar("F", bound=Callable[..., Any]) +KT = TypeVar("KT") +VT = TypeVar("VT") class _CachedFunction(Generic[F]): @@ -61,13 +74,19 @@ cache_pending_metric = Gauge( ["name"], ) -_CacheSentinel = object() + +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. 
+ sentinel = object() class CacheEntry: __slots__ = ["deferred", "callbacks", "invalidated"] - def __init__(self, deferred, callbacks): + def __init__( + self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] + ): self.deferred = deferred self.callbacks = set(callbacks) self.invalidated = False @@ -80,7 +99,13 @@ class CacheEntry: self.callbacks.clear() -class Cache: +class Cache(Generic[KT, VT]): + """Wraps an LruCache, adding support for Deferred results. + + It expects that each entry added with set() will be a Deferred; likewise get() + may return an ObservableDeferred. + """ + __slots__ = ( "cache", "name", @@ -103,19 +128,23 @@ class Cache: Args: name: The name of the cache max_entries: Maximum amount of entries that the cache will hold - keylen: The length of the tuple used as the cache key + keylen: The length of the tuple used as the cache key. Ignored unless + `tree` is True. tree: Use a TreeCache instead of a dict as the underlying cache type iterable: If True, count each item in the cached object as an entry, rather than each cached object apply_cache_factor_from_config: Whether cache factors specified in the config file affect `max_entries` - - Returns: - Cache """ cache_type = TreeCache if tree else dict - self._pending_deferred_cache = cache_type() + # _pending_deferred_cache maps from the key value to a `CacheEntry` object. + self._pending_deferred_cache = ( + cache_type() + ) # type: MutableMapping[KT, CacheEntry] + + # cache is used for completed results and maps to the result itself, rather than + # a Deferred. self.cache = LruCache( max_size=max_entries, keylen=keylen, @@ -155,7 +184,13 @@ class Cache: "Cache objects can only be accessed from the main thread" ) - def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True): + def get( + self, + key: KT, + default=_Sentinel.sentinel, + callback: Optional[Callable[[], None]] = None, + update_metrics: bool = True, + ): """Looks the key up in the caches. 
Args: @@ -166,30 +201,32 @@ class Cache: update_metrics (bool): whether to update the cache hit rate metrics Returns: - Either an ObservableDeferred or the raw result + Either an ObservableDeferred or the result itself """ callbacks = [callback] if callback else [] - val = self._pending_deferred_cache.get(key, _CacheSentinel) - if val is not _CacheSentinel: + val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if val is not _Sentinel.sentinel: val.callbacks.update(callbacks) if update_metrics: self.metrics.inc_hits() return val.deferred - val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) - if val is not _CacheSentinel: + val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) + if val is not _Sentinel.sentinel: self.metrics.inc_hits() return val if update_metrics: self.metrics.inc_misses() - if default is _CacheSentinel: + if default is _Sentinel.sentinel: raise KeyError() else: return default - def set(self, key, value, callback=None): + def set( + self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None + ) -> ObservableDeferred: if not isinstance(value, defer.Deferred): raise TypeError("not a Deferred") @@ -248,7 +285,7 @@ class Cache: observer.addCallbacks(cb, eb) return observable - def prefill(self, key, value, callback=None): + def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) @@ -267,7 +304,7 @@ class Cache: if entry: entry.invalidate() - def invalidate_many(self, key): + def invalidate_many(self, key: KT): self.check_thread() if not isinstance(key, tuple): raise TypeError("The cache key must be a tuple not %r" % (type(key),)) @@ -275,7 +312,7 @@ class Cache: # if we have a pending lookup for this key, remove it from the # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(key, None) + entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): entry.invalidate() @@ -396,7 +433,7 @@ class CacheDescriptor(_CacheDescriptorBase): keylen=self.num_args, tree=self.tree, iterable=self.iterable, - ) + ) # type: Cache[Tuple, Any] def get_cache_key_gen(args, kwargs): """Given some args/kwargs return a generator that resolves into diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 4bc1a67b58..33eae2b7c4 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -64,7 +64,8 @@ class LruCache: Args: max_size: The maximum amount of entries the cache can hold - keylen: The length of the tuple used as the cache key + keylen: The length of the tuple used as the cache key. Ignored unless + cache_type is `TreeCache`. cache_type (type): type of underlying cache to be used. 
Typically one of dict -- cgit 1.5.1 From 9f87da0a84f93096e228f01f1139c9b5db8ea3d4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 19:43:37 +0100 Subject: Rename Cache->DeferredCache --- synapse/replication/slave/storage/client_ips.py | 6 +++--- synapse/storage/databases/main/client_ips.py | 4 ++-- synapse/storage/databases/main/devices.py | 4 ++-- synapse/storage/databases/main/events_worker.py | 4 ++-- synapse/util/caches/descriptors.py | 19 ++++++++++++------- tests/storage/test__base.py | 10 +++++----- tests/test_metrics.py | 4 ++-- tests/util/caches/test_descriptors.py | 4 ++-- 8 files changed, 30 insertions(+), 25 deletions(-) (limited to 'synapse/util/caches') diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 273d627fad..40ea78a353 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.descriptors import DeferredCache from ._base import BaseSlavedStore @@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.client_ip_last_seen = Cache( + self.client_ip_last_seen = DeferredCache( name="client_ip_last_seen", keylen=4, max_entries=50000 - ) # type: Cache[tuple, int] + ) # type: DeferredCache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index a25a888443..ad32701d36 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.descriptors import DeferredCache logger = logging.getLogger(__name__) @@ -410,7 +410,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - self.client_ip_last_seen = Cache( + self.client_ip_last_seen = DeferredCache( name="client_ip_last_seen", keylen=4, max_entries=50000 ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 88fd97e1df..d903155e89 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -34,7 +34,7 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder -from synapse.util.caches.descriptors import Cache, cached, cachedList +from synapse.util.caches.descriptors import DeferredCache, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -1004,7 +1004,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. 
- self.device_id_exists_cache = Cache( + self.device_id_exists_cache = DeferredCache( name="device_id_exists", keylen=2, max_entries=10000 ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 3ec4d1d9c2..be7f60f2e8 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached +from synapse.util.caches.descriptors import DeferredCache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -145,7 +145,7 @@ class EventsWorkerStore(SQLBaseStore): self._cleanup_old_transaction_ids, ) - self._get_event_cache = Cache( + self._get_event_cache = DeferredCache( "*getEvent*", keylen=3, max_entries=hs.config.caches.event_cache_size, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 14458bc20f..7c9fe199bf 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -99,7 +99,7 @@ class CacheEntry: self.callbacks.clear() -class Cache(Generic[KT, VT]): +class DeferredCache(Generic[KT, VT]): """Wraps an LruCache, adding support for Deferred results. It expects that each entry added with set() will be a Deferred; likewise get() @@ -225,7 +225,10 @@ class Cache(Generic[KT, VT]): return default def set( - self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None + self, + key: KT, + value: defer.Deferred, + callback: Optional[Callable[[], None]] = None, ) -> ObservableDeferred: if not isinstance(value, defer.Deferred): raise TypeError("not a Deferred") @@ -427,13 +430,13 @@ class CacheDescriptor(_CacheDescriptorBase): self.iterable = iterable def __get__(self, obj, owner): - cache = Cache( + cache = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, iterable=self.iterable, - ) # type: Cache[Tuple, Any] + ) # type: DeferredCache[Tuple, Any] def get_cache_key_gen(args, kwargs): """Given some args/kwargs return a generator that resolves into @@ -677,9 +680,9 @@ class _CacheContext: _cache_context_objects = ( WeakValueDictionary() - ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext] + ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext] - def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None + def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None self._cache = cache self._cache_key = cache_key @@ -688,7 +691,9 @@ class _CacheContext: self._cache.invalidate(self._cache_key) @classmethod - def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext + def get_instance( + cls, cache, cache_key + ): # type: (DeferredCache, CacheKey) -> _CacheContext """Returns an instance constructed with the given arguments. A new instance is only created if none already exists. 
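(For illustration, a minimal sketch of how the renamed, now-generic cache is used. The cache name and the key/value types here are invented for the example, and as of this commit the class still lives in `synapse.util.caches.descriptors`; a later commit in this series moves it to its own module.)

    from twisted.internet import defer

    from synapse.util.caches.descriptors import DeferredCache

    # A hypothetical cache mapping user IDs (str) to display names (str).
    # The Generic[KT, VT] parameters added above let mypy check the keys
    # passed to get()/set() as well as the values the Deferreds resolve to.
    display_names = DeferredCache(
        name="display_names", max_entries=1000
    )  # type: DeferredCache[str, str]

    d = defer.Deferred()  # type: defer.Deferred
    display_names.set("@alice:example.com", d)

    # While the lookup is pending, get() returns a deferred; once the
    # pending deferred completes, get() returns the result itself, as
    # DeferredCacheTestCase.test_invalidate_all below demonstrates.
    d.callback("Alice")
    assert display_names.get("@alice:example.com") == "Alice"
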
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index f5afed017c..00adcab7b9 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,14 +20,14 @@ from mock import Mock from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import Cache, cached +from synapse.util.caches.descriptors import DeferredCache, cached from tests import unittest -class CacheTestCase(unittest.HomeserverTestCase): +class DeferredCacheTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.cache = Cache("test") + self.cache = DeferredCache("test") def test_empty(self): failed = False @@ -56,7 +56,7 @@ class CacheTestCase(unittest.HomeserverTestCase): self.assertTrue(failed) def test_eviction(self): - cache = Cache("test", max_entries=2) + cache = DeferredCache("test", max_entries=2) cache.prefill(1, "one") cache.prefill(2, "two") @@ -74,7 +74,7 @@ class CacheTestCase(unittest.HomeserverTestCase): cache.get(3) def test_eviction_lru(self): - cache = Cache("test", max_entries=2) + cache = DeferredCache("test", max_entries=2) cache.prefill(1, "one") cache.prefill(2, "two") diff --git a/tests/test_metrics.py b/tests/test_metrics.py index f5f63d8ed6..1c03a52f7c 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,7 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.descriptors import DeferredCache from tests import unittest @@ -138,7 +138,7 @@ class CacheMetricsTests(unittest.HomeserverTestCase): Caches produce metrics reflecting their state when scraped. """ CACHE_NAME = "cache_metrics_test_fgjkbdfg" - cache = Cache(CACHE_NAME, max_entries=777) + cache = DeferredCache(CACHE_NAME, max_entries=777) items = { x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 677e925477..bd870b4a33 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -42,9 +42,9 @@ def run_on_reactor(): return make_deferred_yieldable(d) -class CacheTestCase(unittest.TestCase): +class DeferredCacheTestCase(unittest.TestCase): def test_invalidate_all(self): - cache = descriptors.Cache("testcache") + cache = descriptors.DeferredCache("testcache") callback_record = [False, False] -- cgit 1.5.1 From 4182bb812f21d7231ff0efc5e93e5f2e88f6605e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 23:25:23 +0100 Subject: move DeferredCache into its own module --- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/storage/databases/main/client_ips.py | 2 +- synapse/storage/databases/main/devices.py | 3 +- synapse/storage/databases/main/events_worker.py | 3 +- synapse/util/caches/deferred_cache.py | 292 ++++++++++++++++++++++++ synapse/util/caches/descriptors.py | 284 +---------------------- tests/storage/test__base.py | 3 +- tests/test_metrics.py | 2 +- tests/util/caches/test_deferred_cache.py | 64 ++++++ tests/util/caches/test_descriptors.py | 44 ---- 10 files changed, 367 insertions(+), 332 deletions(-) create mode 100644 synapse/util/caches/deferred_cache.py create mode 100644 tests/util/caches/test_deferred_cache.py (limited to 'synapse/util/caches') diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 
40ea78a353..4b0ea0cc01 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache from ._base import BaseSlavedStore diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index ad32701d36..9e66e6648a 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d903155e89..e662a20d24 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -34,7 +34,8 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder -from synapse.util.caches.descriptors import DeferredCache, cached, cachedList +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index be7f60f2e8..ff150f0be7 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,8 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.descriptors import DeferredCache, cached +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py new file mode 100644 index 0000000000..f728cd2cf2 --- /dev/null +++ b/synapse/util/caches/deferred_cache.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import enum +import threading +from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast + +from prometheus_client import Gauge + +from twisted.internet import defer + +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.caches import register_cache +from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry + +cache_pending_metric = Gauge( + "synapse_util_caches_cache_pending", + "Number of lookups currently pending for this cache", + ["name"], +) + + +KT = TypeVar("KT") +VT = TypeVar("VT") + + +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() + + +class DeferredCache(Generic[KT, VT]): + """Wraps an LruCache, adding support for Deferred results. + + It expects that each entry added with set() will be a Deferred; likewise get() + may return an ObservableDeferred. + """ + + __slots__ = ( + "cache", + "name", + "keylen", + "thread", + "metrics", + "_pending_deferred_cache", + ) + + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key. Ignored unless + `tree` is True. + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + """ + cache_type = TreeCache if tree else dict + + # _pending_deferred_cache maps from the key value to a `CacheEntry` object. + self._pending_deferred_cache = ( + cache_type() + ) # type: MutableMapping[KT, CacheEntry] + + # cache is used for completed results and maps to the result itself, rather than + # a Deferred. + self.cache = LruCache( + max_size=max_entries, + keylen=keylen, + cache_type=cache_type, + size_callback=(lambda d: len(d)) if iterable else None, + evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, + ) + + self.name = name + self.keylen = keylen + self.thread = None # type: Optional[threading.Thread] + self.metrics = register_cache( + "cache", + name, + self.cache, + collect_callback=self._metrics_collection_callback, + ) + + @property + def max_entries(self): + return self.cache.max_size + + def _on_evicted(self, evicted_count): + self.metrics.inc_evictions(evicted_count) + + def _metrics_collection_callback(self): + cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get( + self, + key: KT, + default=_Sentinel.sentinel, + callback: Optional[Callable[[], None]] = None, + update_metrics: bool = True, + ): + """Looks the key up in the caches. + + Args: + key(tuple) + default: What is returned if key is not in the caches. 
If not + specified then function throws KeyError instead + callback(fn): Gets called when the entry in the cache is invalidated + update_metrics (bool): whether to update the cache hit rate metrics + + Returns: + Either an ObservableDeferred or the result itself + """ + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if val is not _Sentinel.sentinel: + val.callbacks.update(callbacks) + if update_metrics: + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) + if val is not _Sentinel.sentinel: + self.metrics.inc_hits() + return val + + if update_metrics: + self.metrics.inc_misses() + + if default is _Sentinel.sentinel: + raise KeyError() + else: + return default + + def set( + self, + key: KT, + value: defer.Deferred, + callback: Optional[Callable[[], None]] = None, + ) -> ObservableDeferred: + if not isinstance(value, defer.Deferred): + raise TypeError("not a Deferred") + + callbacks = [callback] if callback else [] + self.check_thread() + observable = ObservableDeferred(value, consumeErrors=True) + observer = observable.observe() + entry = CacheEntry(deferred=observable, callbacks=callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def compare_and_pop(): + """Check if our entry is still the one in _pending_deferred_cache, and + if so, pop it. + + Returns true if the entries matched. + """ + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + return True + + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. (We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + return False + + def cb(result): + if compare_and_pop(): + self.cache.set(key, result, entry.callbacks) + else: + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. + entry.invalidate() + + def eb(_fail): + compare_and_pop() + entry.invalidate() + + # once the deferred completes, we can move the entry from the + # _pending_deferred_cache to the real cache. + # + observer.addCallbacks(cb, eb) + return observable + + def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) + + def invalidate(self, key): + self.check_thread() + self.cache.pop(key, None) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. + entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. 
+ if entry: + entry.invalidate() + + def invalidate_many(self, key: KT): + self.check_thread() + if not isinstance(key, tuple): + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + self.cache.del_multi(key) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above + entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) + if entry_dict is not None: + for entry in iterate_tree_cache_entry(entry_dict): + entry.invalidate() + + def invalidate_all(self): + self.check_thread() + self.cache.clear() + for entry in self._pending_deferred_cache.values(): + entry.invalidate() + self._pending_deferred_cache.clear() + + +class CacheEntry: + __slots__ = ["deferred", "callbacks", "invalidated"] + + def __init__( + self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] + ): + self.deferred = deferred + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 7c9fe199bf..1f43886804 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,44 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum import functools import inspect import logging -import threading -from typing import ( - Any, - Callable, - Generic, - Iterable, - MutableMapping, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from weakref import WeakValueDictionary -from prometheus_client import Gauge - from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry - -from . import register_cache +from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] F = TypeVar("F", bound=Callable[..., Any]) -KT = TypeVar("KT") -VT = TypeVar("VT") class _CachedFunction(Generic[F]): @@ -68,266 +48,6 @@ class _CachedFunction(Generic[F]): __call__ = None # type: F -cache_pending_metric = Gauge( - "synapse_util_caches_cache_pending", - "Number of lookups currently pending for this cache", - ["name"], -) - - -class _Sentinel(enum.Enum): - # defining a sentinel in this way allows mypy to correctly handle the - # type of a dictionary lookup. - sentinel = object() - - -class CacheEntry: - __slots__ = ["deferred", "callbacks", "invalidated"] - - def __init__( - self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] - ): - self.deferred = deferred - self.callbacks = set(callbacks) - self.invalidated = False - - def invalidate(self): - if not self.invalidated: - self.invalidated = True - for callback in self.callbacks: - callback() - self.callbacks.clear() - - -class DeferredCache(Generic[KT, VT]): - """Wraps an LruCache, adding support for Deferred results. - - It expects that each entry added with set() will be a Deferred; likewise get() - may return an ObservableDeferred. 
- """ - - __slots__ = ( - "cache", - "name", - "keylen", - "thread", - "metrics", - "_pending_deferred_cache", - ) - - def __init__( - self, - name: str, - max_entries: int = 1000, - keylen: int = 1, - tree: bool = False, - iterable: bool = False, - apply_cache_factor_from_config: bool = True, - ): - """ - Args: - name: The name of the cache - max_entries: Maximum amount of entries that the cache will hold - keylen: The length of the tuple used as the cache key. Ignored unless - `tree` is True. - tree: Use a TreeCache instead of a dict as the underlying cache type - iterable: If True, count each item in the cached object as an entry, - rather than each cached object - apply_cache_factor_from_config: Whether cache factors specified in the - config file affect `max_entries` - """ - cache_type = TreeCache if tree else dict - - # _pending_deferred_cache maps from the key value to a `CacheEntry` object. - self._pending_deferred_cache = ( - cache_type() - ) # type: MutableMapping[KT, CacheEntry] - - # cache is used for completed results and maps to the result itself, rather than - # a Deferred. - self.cache = LruCache( - max_size=max_entries, - keylen=keylen, - cache_type=cache_type, - size_callback=(lambda d: len(d)) if iterable else None, - evicted_callback=self._on_evicted, - apply_cache_factor_from_config=apply_cache_factor_from_config, - ) - - self.name = name - self.keylen = keylen - self.thread = None # type: Optional[threading.Thread] - self.metrics = register_cache( - "cache", - name, - self.cache, - collect_callback=self._metrics_collection_callback, - ) - - @property - def max_entries(self): - return self.cache.max_size - - def _on_evicted(self, evicted_count): - self.metrics.inc_evictions(evicted_count) - - def _metrics_collection_callback(self): - cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get( - self, - key: KT, - default=_Sentinel.sentinel, - callback: Optional[Callable[[], None]] = None, - update_metrics: bool = True, - ): - """Looks the key up in the caches. - - Args: - key(tuple) - default: What is returned if key is not in the caches. 
If not - specified then function throws KeyError instead - callback(fn): Gets called when the entry in the cache is invalidated - update_metrics (bool): whether to update the cache hit rate metrics - - Returns: - Either an ObservableDeferred or the result itself - """ - callbacks = [callback] if callback else [] - val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) - if val is not _Sentinel.sentinel: - val.callbacks.update(callbacks) - if update_metrics: - self.metrics.inc_hits() - return val.deferred - - val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) - if val is not _Sentinel.sentinel: - self.metrics.inc_hits() - return val - - if update_metrics: - self.metrics.inc_misses() - - if default is _Sentinel.sentinel: - raise KeyError() - else: - return default - - def set( - self, - key: KT, - value: defer.Deferred, - callback: Optional[Callable[[], None]] = None, - ) -> ObservableDeferred: - if not isinstance(value, defer.Deferred): - raise TypeError("not a Deferred") - - callbacks = [callback] if callback else [] - self.check_thread() - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) - - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry: - existing_entry.invalidate() - - self._pending_deferred_cache[key] = entry - - def compare_and_pop(): - """Check if our entry is still the one in _pending_deferred_cache, and - if so, pop it. - - Returns true if the entries matched. - """ - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - return True - - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. (We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - - return False - - def cb(result): - if compare_and_pop(): - self.cache.set(key, result, entry.callbacks) - else: - # we're not going to put this entry into the cache, so need - # to make sure that the invalidation callbacks are called. - # That was probably done when _pending_deferred_cache was - # updated, but it's possible that `set` was called without - # `invalidate` being previously called, in which case it may - # not have been. Either way, let's double-check now. - entry.invalidate() - - def eb(_fail): - compare_and_pop() - entry.invalidate() - - # once the deferred completes, we can move the entry from the - # _pending_deferred_cache to the real cache. - # - observer.addCallbacks(cb, eb) - return observable - - def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): - callbacks = [callback] if callback else [] - self.cache.set(key, value, callbacks=callbacks) - - def invalidate(self, key): - self.check_thread() - self.cache.pop(key, None) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, which will (a) stop it being returned - # for future queries and (b) stop it being persisted as a proper entry - # in self.cache. - entry = self._pending_deferred_cache.pop(key, None) - - # run the invalidation callbacks now, rather than waiting for the - # deferred to resolve. 
- if entry: - entry.invalidate() - - def invalidate_many(self, key: KT): - self.check_thread() - if not isinstance(key, tuple): - raise TypeError("The cache key must be a tuple not %r" % (type(key),)) - self.cache.del_multi(key) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) - if entry_dict is not None: - for entry in iterate_tree_cache_entry(entry_dict): - entry.invalidate() - - def invalidate_all(self): - self.check_thread() - self.cache.clear() - for entry in self._pending_deferred_cache.values(): - entry.invalidate() - self._pending_deferred_cache.clear() - - class _CacheDescriptorBase: def __init__(self, orig: _CachedFunction, num_args, cache_context=False): self.orig = orig diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 00adcab7b9..2598dbe0a7 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,7 +20,8 @@ from mock import Mock from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import DeferredCache, cached +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached from tests import unittest diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 1c03a52f7c..759e4cd048 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,7 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache from tests import unittest diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py new file mode 100644 index 0000000000..9b6acdfc43 --- /dev/null +++ b/tests/util/caches/test_deferred_cache.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from functools import partial + +from twisted.internet import defer + +import synapse.util.caches.deferred_cache + + +class DeferredCacheTestCase(unittest.TestCase): + def test_invalidate_all(self): + cache = synapse.util.caches.deferred_cache.DeferredCache("testcache") + + callback_record = [False, False] + + def record_callback(idx): + callback_record[idx] = True + + # add a couple of pending entries + d1 = defer.Deferred() + cache.set("key1", d1, partial(record_callback, 0)) + + d2 = defer.Deferred() + cache.set("key2", d2, partial(record_callback, 1)) + + # lookup should return observable deferreds + self.assertFalse(cache.get("key1").has_called()) + self.assertFalse(cache.get("key2").has_called()) + + # let one of the lookups complete + d2.callback("result2") + + # for now at least, the cache will return real results rather than an + # observabledeferred + self.assertEqual(cache.get("key2"), "result2") + + # now do the invalidation + cache.invalidate_all() + + # lookup should return none + self.assertIsNone(cache.get("key1", None)) + self.assertIsNone(cache.get("key2", None)) + + # both callbacks should have been callbacked + self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") + self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") + + # letting the other lookup complete should do nothing + d1.callback("result1") + self.assertIsNone(cache.get("key1", None)) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index bd870b4a33..3d1f960869 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from functools import partial import mock @@ -42,49 +41,6 @@ def run_on_reactor(): return make_deferred_yieldable(d) -class DeferredCacheTestCase(unittest.TestCase): - def test_invalidate_all(self): - cache = descriptors.DeferredCache("testcache") - - callback_record = [False, False] - - def record_callback(idx): - callback_record[idx] = True - - # add a couple of pending entries - d1 = defer.Deferred() - cache.set("key1", d1, partial(record_callback, 0)) - - d2 = defer.Deferred() - cache.set("key2", d2, partial(record_callback, 1)) - - # lookup should return observable deferreds - self.assertFalse(cache.get("key1").has_called()) - self.assertFalse(cache.get("key2").has_called()) - - # let one of the lookups complete - d2.callback("result2") - - # for now at least, the cache will return real results rather than an - # observabledeferred - self.assertEqual(cache.get("key2"), "result2") - - # now do the invalidation - cache.invalidate_all() - - # lookup should return none - self.assertIsNone(cache.get("key1", None)) - self.assertIsNone(cache.get("key2", None)) - - # both callbacks should have been callbacked - self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") - self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") - - # letting the other lookup complete should do nothing - d1.callback("result1") - self.assertIsNone(cache.get("key1", None)) - - class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): -- cgit 1.5.1 From 8075504a600b47ac93faf9b605ade691ae0fbcd3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 15 Oct 2020 11:44:39 +0100 Subject: Enable mypy for synapse.util.caches (#8547) This seemed to entail dragging in a type stub for SortedList. --- changelog.d/8547.misc | 1 + mypy.ini | 4 +- stubs/sortedcontainers/__init__.pyi | 11 +-- stubs/sortedcontainers/sortedlist.pyi | 177 ++++++++++++++++++++++++++++++++++ synapse/util/caches/ttlcache.py | 2 +- 5 files changed, 185 insertions(+), 10 deletions(-) create mode 100644 changelog.d/8547.misc create mode 100644 stubs/sortedcontainers/sortedlist.pyi (limited to 'synapse/util/caches') diff --git a/changelog.d/8547.misc b/changelog.d/8547.misc new file mode 100644 index 0000000000..fafb1c8347 --- /dev/null +++ b/changelog.d/8547.misc @@ -0,0 +1 @@ +Enable mypy type checking for `synapse.util.caches`. 
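(As a rough illustration of what the stub below buys: once `sortedcontainers` is typed, mypy can check element types on the `SortedList` used by `TTLCache` at the end of this patch. A self-contained sketch with made-up values:)

    from sortedcontainers import SortedList

    # With the stub in place, mypy infers SortedList[int] from the type
    # comment and would flag e.g. expiries.add("soon") as a type error.
    expiries = SortedList()  # type: SortedList[int]
    expiries.add(1602806400)
    expiries.add(1602810000)

    # irange() yields the entries between two (inclusive) bounds in order.
    assert list(expiries.irange(1602800000, 1602809999)) == [1602806400]
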
diff --git a/mypy.ini b/mypy.ini index f08fe992a4..9748f6258c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -64,9 +64,7 @@ files = synapse/streams, synapse/types.py, synapse/util/async_helpers.py, - synapse/util/caches/descriptors.py, - synapse/util/caches/response_cache.py, - synapse/util/caches/stream_change_cache.py, + synapse/util/caches, synapse/util/metrics.py, tests/replication, tests/test_utils, diff --git a/stubs/sortedcontainers/__init__.pyi b/stubs/sortedcontainers/__init__.pyi index 073b806d3c..fa307483fe 100644 --- a/stubs/sortedcontainers/__init__.pyi +++ b/stubs/sortedcontainers/__init__.pyi @@ -1,13 +1,12 @@ -from .sorteddict import ( - SortedDict, - SortedKeysView, - SortedItemsView, - SortedValuesView, -) +from .sorteddict import SortedDict, SortedItemsView, SortedKeysView, SortedValuesView +from .sortedlist import SortedKeyList, SortedList, SortedListWithKey __all__ = [ "SortedDict", "SortedKeysView", "SortedItemsView", "SortedValuesView", + "SortedKeyList", + "SortedList", + "SortedListWithKey", ] diff --git a/stubs/sortedcontainers/sortedlist.pyi b/stubs/sortedcontainers/sortedlist.pyi new file mode 100644 index 0000000000..8f6086b3ff --- /dev/null +++ b/stubs/sortedcontainers/sortedlist.pyi @@ -0,0 +1,177 @@ +# stub for SortedList. This is an exact copy of +# https://github.com/grantjenks/python-sortedcontainers/blob/a419ffbd2b1c935b09f11f0971696e537fd0c510/sortedcontainers/sortedlist.pyi +# (from https://github.com/grantjenks/python-sortedcontainers/pull/107) + +from typing import ( + Any, + Callable, + Generic, + Iterable, + Iterator, + List, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +_T = TypeVar("_T") +_SL = TypeVar("_SL", bound=SortedList) +_SKL = TypeVar("_SKL", bound=SortedKeyList) +_Key = Callable[[_T], Any] +_Repr = Callable[[], str] + +def recursive_repr(fillvalue: str = ...) -> Callable[[_Repr], _Repr]: ... + +class SortedList(MutableSequence[_T]): + + DEFAULT_LOAD_FACTOR: int = ... + def __init__( + self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ..., + ): ... + # NB: currently mypy does not honour return type, see mypy #3307 + @overload + def __new__(cls: Type[_SL], iterable: None, key: None) -> _SL: ... + @overload + def __new__(cls: Type[_SL], iterable: None, key: _Key[_T]) -> SortedKeyList[_T]: ... + @overload + def __new__(cls: Type[_SL], iterable: Iterable[_T], key: None) -> _SL: ... + @overload + def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ... + @property + def key(self) -> Optional[Callable[[_T], Any]]: ... + def _reset(self, load: int) -> None: ... + def clear(self) -> None: ... + def _clear(self) -> None: ... + def add(self, value: _T) -> None: ... + def _expand(self, pos: int) -> None: ... + def update(self, iterable: Iterable[_T]) -> None: ... + def _update(self, iterable: Iterable[_T]) -> None: ... + def discard(self, value: _T) -> None: ... + def remove(self, value: _T) -> None: ... + def _delete(self, pos: int, idx: int) -> None: ... + def _loc(self, pos: int, idx: int) -> int: ... + def _pos(self, idx: int) -> int: ... + def _build_index(self) -> None: ... + def __contains__(self, value: Any) -> bool: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> List[_T]: ... + @overload + def _getitem(self, index: int) -> _T: ... + @overload + def _getitem(self, index: slice) -> List[_T]: ... 
+ @overload + def __setitem__(self, index: int, value: _T) -> None: ... + @overload + def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ... + def __iter__(self) -> Iterator[_T]: ... + def __reversed__(self) -> Iterator[_T]: ... + def __len__(self) -> int: ... + def reverse(self) -> None: ... + def islice( + self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool, + ) -> Iterator[_T]: ... + def _islice( + self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool, + ) -> Iterator[_T]: ... + def irange( + self, + minimum: Optional[int] = ..., + maximum: Optional[int] = ..., + inclusive: Tuple[bool, bool] = ..., + reverse: bool = ..., + ) -> Iterator[_T]: ... + def bisect_left(self, value: _T) -> int: ... + def bisect_right(self, value: _T) -> int: ... + def bisect(self, value: _T) -> int: ... + def _bisect_right(self, value: _T) -> int: ... + def count(self, value: _T) -> int: ... + def copy(self: _SL) -> _SL: ... + def __copy__(self: _SL) -> _SL: ... + def append(self, value: _T) -> None: ... + def extend(self, values: Iterable[_T]) -> None: ... + def insert(self, index: int, value: _T) -> None: ... + def pop(self, index: int = ...) -> _T: ... + def index( + self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + ) -> int: ... + def __add__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __radd__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __iadd__(self: _SL, other: Iterable[_T]) -> _SL: ... + def __mul__(self: _SL, num: int) -> _SL: ... + def __rmul__(self: _SL, num: int) -> _SL: ... + def __imul__(self: _SL, num: int) -> _SL: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __lt__(self, other: Sequence[_T]) -> bool: ... + def __gt__(self, other: Sequence[_T]) -> bool: ... + def __le__(self, other: Sequence[_T]) -> bool: ... + def __ge__(self, other: Sequence[_T]) -> bool: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + +class SortedKeyList(SortedList[_T]): + def __init__( + self, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ... + ) -> None: ... + def __new__( + cls, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ... + ) -> SortedKeyList[_T]: ... + @property + def key(self) -> Callable[[_T], Any]: ... + def clear(self) -> None: ... + def _clear(self) -> None: ... + def add(self, value: _T) -> None: ... + def _expand(self, pos: int) -> None: ... + def update(self, iterable: Iterable[_T]) -> None: ... + def _update(self, iterable: Iterable[_T]) -> None: ... + # NB: Must be T to be safely passed to self.func, yet base class imposes Any + def __contains__(self, value: _T) -> bool: ... # type: ignore + def discard(self, value: _T) -> None: ... + def remove(self, value: _T) -> None: ... + def _delete(self, pos: int, idx: int) -> None: ... + def irange( + self, + minimum: Optional[int] = ..., + maximum: Optional[int] = ..., + inclusive: Tuple[bool, bool] = ..., + reverse: bool = ..., + ): ... + def irange_key( + self, + min_key: Optional[Any] = ..., + max_key: Optional[Any] = ..., + inclusive: Tuple[bool, bool] = ..., + reserve: bool = ..., + ): ... + def bisect_left(self, value: _T) -> int: ... + def bisect_right(self, value: _T) -> int: ... + def bisect(self, value: _T) -> int: ... + def bisect_key_left(self, key: Any) -> int: ... + def _bisect_key_left(self, key: Any) -> int: ... + def bisect_key_right(self, key: Any) -> int: ... + def _bisect_key_right(self, key: Any) -> int: ... + def bisect_key(self, key: Any) -> int: ... 
+ def count(self, value: _T) -> int: ... + def copy(self: _SKL) -> _SKL: ... + def __copy__(self: _SKL) -> _SKL: ... + def index( + self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + ) -> int: ... + def __add__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __radd__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __iadd__(self: _SKL, other: Iterable[_T]) -> _SKL: ... + def __mul__(self: _SKL, num: int) -> _SKL: ... + def __rmul__(self: _SKL, num: int) -> _SKL: ... + def __imul__(self: _SKL, num: int) -> _SKL: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + +SortedListWithKey = SortedKeyList diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 3e180cafd3..6ce2a3d12b 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -34,7 +34,7 @@ class TTLCache: self._data = {} # the _CacheEntries, sorted by expiry time - self._expiry_list = SortedList() + self._expiry_list = SortedList() # type: SortedList[_CacheEntry] self._timer = timer -- cgit 1.5.1 From 3ee17585cd095e590096683395cfb9a017eac15e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 16 Oct 2020 15:51:57 +0100 Subject: Make LruCache register its own metrics (#8561) rather than have everything that instantiates an LruCache manage metrics separately, have LruCache do it itself. --- changelog.d/8561.misc | 1 + synapse/api/auth.py | 4 +-- synapse/push/push_rule_evaluator.py | 4 +-- synapse/util/caches/__init__.py | 13 ++++++---- synapse/util/caches/deferred_cache.py | 43 ++++++++++-------------------- synapse/util/caches/dictionary_cache.py | 9 +------ synapse/util/caches/lrucache.py | 46 +++++++++++++++++++++++++-------- tests/util/test_lrucache.py | 4 +-- 8 files changed, 62 insertions(+), 62 deletions(-) create mode 100644 changelog.d/8561.misc (limited to 'synapse/util/caches') diff --git a/changelog.d/8561.misc b/changelog.d/8561.misc new file mode 100644 index 0000000000..a40dedfa8e --- /dev/null +++ b/changelog.d/8561.misc @@ -0,0 +1 @@ +Move metric registration code down into `LruCache`. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1071a0576e..eb6f418b13 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -34,7 +34,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging import opentracing as opentracing from synapse.types import StateMap, UserID -from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -70,8 +69,7 @@ class Auth: self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache(10000) - register_cache("cache", "token_cache", self.token_cache) + self.token_cache = LruCache(10000, "token_cache") self._auth_blocking = AuthBlocking(self.hs) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 3a68ce636f..4c95b149c5 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -20,7 +20,6 @@ from typing import Any, Dict, List, Optional, Pattern, Union from synapse.events import EventBase from synapse.types import UserID -from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -186,8 +185,7 @@ class PushRuleEvaluatorForEvent: # Caches (string, is_glob, word_boundary) -> regex for push. 
See _glob_matches -regex_cache = LruCache(50000) -register_cache("cache", "regex_push_cache", regex_cache) +regex_cache = LruCache(50000, "regex_push_cache") def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 8fc05be278..89f0b38535 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -16,7 +16,7 @@ import logging from sys import intern -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Sized import attr from prometheus_client.core import Gauge @@ -92,7 +92,7 @@ class CacheMetric: def register_cache( cache_type: str, cache_name: str, - cache, + cache: Sized, collect_callback: Optional[Callable] = None, resizable: bool = True, resize_callback: Optional[Callable] = None, @@ -100,12 +100,15 @@ def register_cache( """Register a cache object for metric collection and resizing. Args: - cache_type + cache_type: a string indicating the "type" of the cache. This is used + only for deduplication so isn't too important provided it's constant. cache_name: name of the cache - cache: cache itself + cache: cache itself, which must implement __len__(), and may optionally implement + a max_size property collect_callback: If given, a function which is called during metric collection to update additional metrics. - resizable: Whether this cache supports being resized. + resizable: Whether this cache supports being resized, in which case either + resize_callback must be provided, or the cache must support set_max_size(). resize_callback: A function which can be called to resize the cache. Returns: diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index f728cd2cf2..91fdc8142d 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -24,7 +24,6 @@ from prometheus_client import Gauge from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -54,10 +53,7 @@ class DeferredCache(Generic[KT, VT]): __slots__ = ( "cache", - "name", - "keylen", "thread", - "metrics", "_pending_deferred_cache", ) @@ -89,37 +85,27 @@ class DeferredCache(Generic[KT, VT]): cache_type() ) # type: MutableMapping[KT, CacheEntry] + def metrics_cb(): + cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) + # cache is used for completed results and maps to the result itself, rather than # a Deferred. 
self.cache = LruCache( max_size=max_entries, keylen=keylen, + cache_name=name, cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, - evicted_callback=self._on_evicted, + metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, ) - self.name = name - self.keylen = keylen self.thread = None # type: Optional[threading.Thread] - self.metrics = register_cache( - "cache", - name, - self.cache, - collect_callback=self._metrics_collection_callback, - ) @property def max_entries(self): return self.cache.max_size - def _on_evicted(self, evicted_count): - self.metrics.inc_evictions(evicted_count) - - def _metrics_collection_callback(self): - cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) - def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -154,21 +140,18 @@ class DeferredCache(Generic[KT, VT]): if val is not _Sentinel.sentinel: val.callbacks.update(callbacks) if update_metrics: - self.metrics.inc_hits() + m = self.cache.metrics + assert m # we always have a name, so should always have metrics + m.inc_hits() return val.deferred - val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) - if val is not _Sentinel.sentinel: - self.metrics.inc_hits() - return val - - if update_metrics: - self.metrics.inc_misses() - - if default is _Sentinel.sentinel: + val = self.cache.get( + key, default, callbacks=callbacks, update_metrics=update_metrics + ) + if val is _Sentinel.sentinel: raise KeyError() else: - return default + return val def set( self, diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 8592b93689..8b426c005b 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -19,8 +19,6 @@ from collections import namedtuple from synapse.util.caches.lrucache import LruCache -from . import register_cache - logger = logging.getLogger(__name__) @@ -46,18 +44,16 @@ class DictionaryCache: """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries, size_callback=len) + self.cache = LruCache(max_size=max_entries, cache_name=name, size_callback=len) self.name = name self.sequence = 0 self.thread = None - # caches_by_name[name] = self.cache class Sentinel: __slots__ = [] self.sentinel = Sentinel() - self.metrics = register_cache("dictionary", name, self.cache) def check_thread(self): expected_thread = self.thread @@ -82,8 +78,6 @@ class DictionaryCache: """ entry = self.cache.get(key, self.sentinel) if entry is not self.sentinel: - self.metrics.inc_hits() - if dict_keys is None: return DictionaryEntry( entry.full, entry.known_absent, dict(entry.value) @@ -95,7 +89,6 @@ class DictionaryCache: {k: entry.value[k] for k in dict_keys if k in entry.value}, ) - self.metrics.inc_misses() return DictionaryEntry(False, set(), {}) def invalidate(self, key): diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 33eae2b7c4..e4804f79e0 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -18,6 +18,7 @@ from functools import wraps from typing import Callable, Optional, Type, Union from synapse.config import cache as cache_config +from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.treecache import TreeCache @@ -43,27 +44,29 @@ class _Node: class LruCache: """ - Least-recently-used cache. 
+ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. + Supports del_multi only if cache_type=TreeCache If cache_type=TreeCache, all keys must be tuples. - - Can also set callbacks on objects when getting/setting which are fired - when that key gets invalidated/evicted. """ def __init__( self, max_size: int, + cache_name: Optional[str] = None, keylen: int = 1, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable] = None, - evicted_callback: Optional[Callable] = None, + metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, ): """ Args: max_size: The maximum amount of entries the cache can hold + cache_name: The name of this cache, for the prometheus metrics. If unset, + no metrics will be reported on this cache. + keylen: The length of the tuple used as the cache key. Ignored unless cache_type is `TreeCache`. @@ -73,9 +76,13 @@ class LruCache: size_callback (func(V) -> int | None): - evicted_callback (func(int)|None): - if not None, called on eviction with the size of the evicted - entry + metrics_collection_callback: + metrics collection callback. This is called early in the metrics + collection process, before any of the metrics registered with the + prometheus Registry are collected, so can be used to update any dynamic + metrics. + + Ignored if cache_name is None. apply_cache_factor_from_config (bool): If true, `max_size` will be multiplied by a cache factor derived from the homeserver config @@ -94,6 +101,19 @@ class LruCache: else: self.max_size = int(max_size) + if cache_name is not None: + metrics = register_cache( + "lru_cache", + cache_name, + self, + collect_callback=metrics_collection_callback, + ) # type: Optional[CacheMetric] + else: + metrics = None + + # this is exposed for access from outside this class + self.metrics = metrics + list_root = _Node(None, None, None, None) list_root.next_node = list_root list_root.prev_node = list_root @@ -105,8 +125,8 @@ class LruCache: todelete = list_root.prev_node evicted_len = delete_node(todelete) cache.pop(todelete.key, None) - if evicted_callback: - evicted_callback(evicted_len) + if metrics: + metrics.inc_evictions(evicted_len) def synchronized(f): @wraps(f) @@ -169,13 +189,17 @@ class LruCache: return deleted_len @synchronized - def cache_get(key, default=None, callbacks=[]): + def cache_get(key, default=None, callbacks=[], update_metrics=True): node = cache.get(key, None) if node is not None: move_node_to_front(node) node.callbacks.update(callbacks) + if update_metrics and metrics: + metrics.inc_hits() return node.value else: + if update_metrics and metrics: + metrics.inc_misses() return default @synchronized diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 0adb2174af..f12834edab 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -59,7 +59,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEquals(cache.pop("key"), None) def test_del_multi(self): - cache = LruCache(4, 2, cache_type=TreeCache) + cache = LruCache(4, keylen=2, cache_type=TreeCache) cache[("animal", "cat")] = "mew" cache[("animal", "dog")] = "woof" cache[("vehicles", "car")] = "vroom" @@ -160,7 +160,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): m2 = Mock() m3 = Mock() m4 = Mock() - cache = LruCache(4, 2, cache_type=TreeCache) + cache = LruCache(4, keylen=2, cache_type=TreeCache) cache.set(("a", "1"), "value", callbacks=[m1]) cache.set(("a", "2"), "value", 
callbacks=[m2]) -- cgit 1.5.1 From 0ec0bc3886bd72bdf2f64d455a7d777f4573a4f1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 16 Oct 2020 15:56:39 +0100 Subject: type annotations for LruCache --- synapse/api/auth.py | 4 +- synapse/push/push_rule_evaluator.py | 16 ++++---- synapse/util/caches/deferred_cache.py | 5 ++- synapse/util/caches/dictionary_cache.py | 22 ++++++---- synapse/util/caches/lrucache.py | 73 +++++++++++++++++++++++++++------ 5 files changed, 89 insertions(+), 31 deletions(-) (limited to 'synapse/util/caches') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index eb6f418b13..bff87fabde 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -69,7 +69,9 @@ class Auth: self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache(10000, "token_cache") + self.token_cache = LruCache( + 10000, "token_cache" + ) # type: LruCache[str, Tuple[str, bool]] self._auth_blocking = AuthBlocking(self.hs) diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 4c95b149c5..854ffd625e 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -16,7 +16,7 @@ import logging import re -from typing import Any, Dict, List, Optional, Pattern, Union +from typing import Any, Dict, List, Optional, Pattern, Tuple, Union from synapse.events import EventBase from synapse.types import UserID @@ -173,19 +173,21 @@ class PushRuleEvaluatorForEvent: # Similar to _glob_matches, but do not treat display_name as a glob. r = regex_cache.get((display_name, False, True), None) if not r: - r = re.escape(display_name) - r = _re_word_boundary(r) - r = re.compile(r, flags=re.IGNORECASE) + r1 = re.escape(display_name) + r1 = _re_word_boundary(r1) + r = re.compile(r1, flags=re.IGNORECASE) regex_cache[(display_name, False, True)] = r - return r.search(body) + return bool(r.search(body)) def _get_value(self, dotted_key: str) -> Optional[str]: return self._value_cache.get(dotted_key, None) # Caches (string, is_glob, word_boundary) -> regex for push. 
See _glob_matches -regex_cache = LruCache(50000, "regex_push_cache") +regex_cache = LruCache( + 50000, "regex_push_cache" +) # type: LruCache[Tuple[str, bool, bool],Pattern] def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: @@ -203,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: if not r: r = _glob_to_re(glob, word_boundary) regex_cache[(glob, True, word_boundary)] = r - return r.search(value) + return bool(r.search(value)) except re.error: logger.warning("Failed to parse glob to regex: %r", glob) return False diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 91fdc8142d..4026e1f8fa 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -98,7 +98,7 @@ class DeferredCache(Generic[KT, VT]): size_callback=(lambda d: len(d)) if iterable else None, metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, - ) + ) # type: LruCache[KT, VT] self.thread = None # type: Optional[threading.Thread] @@ -240,11 +240,12 @@ class DeferredCache(Generic[KT, VT]): self.check_thread() if not isinstance(key, tuple): raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + key = cast(KT, key) self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) + entry_dict = self._pending_deferred_cache.pop(key, None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): entry.invalidate() diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 8b426c005b..588d2d49f2 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import logging import threading from collections import namedtuple +from typing import Any from synapse.util.caches.lrucache import LruCache @@ -38,23 +39,26 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va return len(self.value) +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() + + class DictionaryCache: """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. 
""" def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries, cache_name=name, size_callback=len) + self.cache = LruCache( + max_size=max_entries, cache_name=name, size_callback=len + ) # type: LruCache[Any, DictionaryEntry] self.name = name self.sequence = 0 self.thread = None - class Sentinel: - __slots__ = [] - - self.sentinel = Sentinel() - def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -76,8 +80,8 @@ class DictionaryCache: Returns: DictionaryEntry """ - entry = self.cache.get(key, self.sentinel) - if entry is not self.sentinel: + entry = self.cache.get(key, _Sentinel.sentinel) + if entry is not _Sentinel.sentinel: if dict_keys is None: return DictionaryEntry( entry.full, entry.known_absent, dict(entry.value) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index e4804f79e0..0eed53d3f4 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -15,12 +15,30 @@ import threading from functools import wraps -from typing import Callable, Optional, Type, Union +from typing import ( + Any, + Callable, + Generic, + Iterable, + Optional, + Type, + TypeVar, + Union, + cast, + overload, +) + +from typing_extensions import Literal from synapse.config import cache as cache_config from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.treecache import TreeCache +T = TypeVar("T") +FT = TypeVar("FT", bound=Callable[..., Any]) +KT = TypeVar("KT") +VT = TypeVar("VT") + def enumerate_leaves(node, depth): if depth == 0: @@ -42,7 +60,7 @@ class _Node: self.callbacks = callbacks -class LruCache: +class LruCache(Generic[KT, VT]): """ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. @@ -128,13 +146,13 @@ class LruCache: if metrics: metrics.inc_evictions(evicted_len) - def synchronized(f): + def synchronized(f: FT) -> FT: @wraps(f) def inner(*args, **kwargs): with lock: return f(*args, **kwargs) - return inner + return cast(FT, inner) cached_cache_len = [0] if size_callback is not None: @@ -188,8 +206,31 @@ class LruCache: node.callbacks.clear() return deleted_len + @overload + def cache_get( + key: KT, + default: Literal[None] = None, + callbacks: Iterable[Callable[[], None]] = ..., + update_metrics: bool = ..., + ) -> Optional[VT]: + ... + + @overload + def cache_get( + key: KT, + default: T, + callbacks: Iterable[Callable[[], None]] = ..., + update_metrics: bool = ..., + ) -> Union[T, VT]: + ... + @synchronized - def cache_get(key, default=None, callbacks=[], update_metrics=True): + def cache_get( + key: KT, + default=None, + callbacks: Iterable[Callable[[], None]] = [], + update_metrics: bool = True, + ): node = cache.get(key, None) if node is not None: move_node_to_front(node) @@ -203,7 +244,7 @@ class LruCache: return default @synchronized - def cache_set(key, value, callbacks=[]): + def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []): node = cache.get(key, None) if node is not None: # We sometimes store large objects, e.g. dicts, which cause @@ -232,7 +273,7 @@ class LruCache: evict() @synchronized - def cache_set_default(key, value): + def cache_set_default(key: KT, value: VT) -> VT: node = cache.get(key, None) if node is not None: return node.value @@ -241,8 +282,16 @@ class LruCache: evict() return value + @overload + def cache_pop(key: KT, default: Literal[None] = None) -> Union[None, VT]: + ... + + @overload + def cache_pop(key: KT, default: T) -> Union[T, VT]: + ... 
+ @synchronized - def cache_pop(key, default=None): + def cache_pop(key: KT, default=None): node = cache.get(key, None) if node: delete_node(node) @@ -252,18 +301,18 @@ class LruCache: return default @synchronized - def cache_del_multi(key): + def cache_del_multi(key: KT) -> None: """ This will only work if constructed with cache_type=TreeCache """ popped = cache.pop(key) if popped is None: return - for leaf in enumerate_leaves(popped, keylen - len(key)): + for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))): delete_node(leaf) @synchronized - def cache_clear(): + def cache_clear() -> None: list_root.next_node = list_root list_root.prev_node = list_root for node in cache.values(): @@ -274,7 +323,7 @@ class LruCache: cached_cache_len[0] = 0 @synchronized - def cache_contains(key): + def cache_contains(key: KT) -> bool: return key in cache self.sentinel = object() -- cgit 1.5.1 From 995cc615a01bb11b70dbf8fdd0eb7f8b3d1fdc1e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 16 Oct 2020 16:14:42 +0100 Subject: Apply suggestions from code review Co-authored-by: Patrick Cloke --- synapse/push/push_rule_evaluator.py | 2 +- synapse/util/caches/lrucache.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/util/caches') diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 854ffd625e..2ce9e444ab 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -187,7 +187,7 @@ class PushRuleEvaluatorForEvent: # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches regex_cache = LruCache( 50000, "regex_push_cache" -) # type: LruCache[Tuple[str, bool, bool],Pattern] +) # type: LruCache[Tuple[str, bool, bool], Pattern] def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 0eed53d3f4..1a2c2d4c0b 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -283,7 +283,7 @@ class LruCache(Generic[KT, VT]): return value @overload - def cache_pop(key: KT, default: Literal[None] = None) -> Union[None, VT]: + def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]: ... 
@overload -- cgit 1.5.1 From 6d7b22041ddf5ecceaf230404ef00c4d0b432727 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 16 Oct 2020 16:21:43 +0100 Subject: review comments --- synapse/util/caches/lrucache.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'synapse/util/caches') diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 1a2c2d4c0b..4e95dd9bf3 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -34,11 +34,16 @@ from synapse.config import cache as cache_config from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.treecache import TreeCache -T = TypeVar("T") +# Function type: the type used for invalidation callbacks FT = TypeVar("FT", bound=Callable[..., Any]) + +# Key and Value type for the cache KT = TypeVar("KT") VT = TypeVar("VT") +# a general type var, distinct from either KT or VT +T = TypeVar("T") + def enumerate_leaves(node, depth): if depth == 0: @@ -227,7 +232,7 @@ class LruCache(Generic[KT, VT]): @synchronized def cache_get( key: KT, - default=None, + default: Optional[T] = None, callbacks: Iterable[Callable[[], None]] = [], update_metrics: bool = True, ): @@ -291,7 +296,7 @@ class LruCache(Generic[KT, VT]): ... @synchronized - def cache_pop(key: KT, default=None): + def cache_pop(key: KT, default: Optional[T] = None): node = cache.get(key, None) if node: delete_node(node) -- cgit 1.5.1 From 97647b33c248f25571bae617365d95434e6a3d5f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 19 Oct 2020 12:20:29 +0100 Subject: Replace DeferredCache with LruCache where possible (#8563) Most of these uses don't need a full-blown DeferredCache; LruCache is lighter and more appropriate. --- changelog.d/8563.misc | 1 + synapse/replication/slave/storage/client_ips.py | 10 +++++----- synapse/storage/_base.py | 12 +++++++----- synapse/storage/databases/main/client_ips.py | 8 ++++---- synapse/storage/databases/main/devices.py | 8 ++++---- synapse/storage/databases/main/events.py | 4 +--- synapse/storage/databases/main/events_worker.py | 11 +++++------ synapse/util/caches/lrucache.py | 3 +++ 8 files changed, 30 insertions(+), 27 deletions(-) create mode 100644 changelog.d/8563.misc (limited to 'synapse/util/caches') diff --git a/changelog.d/8563.misc b/changelog.d/8563.misc new file mode 100644 index 0000000000..eeba8e5fee --- /dev/null +++ b/changelog.d/8563.misc @@ -0,0 +1 @@ +Replace `DeferredCache` with the lighter-weight `LruCache` where possible. 
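In practice the conversion made by this commit is mechanical: `DeferredCache.prefill()`
becomes `LruCache.set()`, and lookups return the stored value directly rather than a
Deferred. A minimal sketch of the resulting usage (the key and timestamp are
illustrative, not taken from the hunks below):

    from synapse.util.caches.lrucache import LruCache

    # LruCache stores already-computed values directly; there is no Deferred
    # wrapping, which is what makes it the lighter-weight choice here.
    last_seen = LruCache(
        cache_name="client_ip_last_seen", keylen=4, max_size=50000
    )  # type: LruCache[tuple, int]

    key = ("@alice:example.com", "token", "1.2.3.4", "agent")
    last_seen.set(key, 1603100000)  # replaces DeferredCache.prefill()
    assert last_seen.get(key, None) == 1603100000  # a plain int, not a Deferred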
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 4b0ea0cc01..0f5b7adef7 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.lrucache import LruCache from ._base import BaseSlavedStore @@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.client_ip_last_seen = DeferredCache( - name="client_ip_last_seen", keylen=4, max_entries=50000 - ) # type: DeferredCache[tuple, int] + self.client_ip_last_seen = LruCache( + cache_name="client_ip_last_seen", keylen=4, max_size=50000 + ) # type: LruCache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) @@ -41,7 +41,7 @@ class SlavedClientIpStore(BaseSlavedStore): if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return - self.client_ip_last_seen.prefill(key, now) + self.client_ip_last_seen.set(key, now) self.hs.get_tcp_replication().send_user_ip( user_id, access_token, ip, user_agent, device_id, now diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index ab49d227de..2b196ded1b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -76,14 +76,16 @@ class SQLBaseStore(metaclass=ABCMeta): """ try: - if key is None: - getattr(self, cache_name).invalidate_all() - else: - getattr(self, cache_name).invalidate(tuple(key)) + cache = getattr(self, cache_name) except AttributeError: # We probably haven't pulled in the cache in this worker, # which is fine. 
- pass + return + + if key is None: + cache.invalidate_all() + else: + cache.invalidate(tuple(key)) def db_to_json(db_content): diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 9e66e6648a..339bd691a4 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -410,8 +410,8 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - self.client_ip_last_seen = DeferredCache( - name="client_ip_last_seen", keylen=4, max_entries=50000 + self.client_ip_last_seen = LruCache( + cache_name="client_ip_last_seen", keylen=4, max_size=50000 ) super().__init__(database, db_conn, hs) @@ -442,7 +442,7 @@ class ClientIpStore(ClientIpWorkerStore): if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return - self.client_ip_last_seen.prefill(key, now) + self.client_ip_last_seen.set(key, now) self._batch_row_update[key] = (user_agent, device_id, now) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e662a20d24..dfb4f87b8f 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -34,8 +34,8 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder -from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -1005,8 +1005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. 
- self.device_id_exists_cache = DeferredCache( - name="device_id_exists", keylen=2, max_entries=10000 + self.device_id_exists_cache = LruCache( + cache_name="device_id_exists", keylen=2, max_size=10000 ) async def store_device( @@ -1052,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) - self.device_id_exists_cache.prefill(key, True) + self.device_id_exists_cache.set(key, True) return inserted except StoreError: raise diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ba3b1769b0..87808c1483 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1051,9 +1051,7 @@ class PersistEventsStore: def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.prefill( - (cache_entry[0].event_id,), cache_entry - ) + self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) txn.call_after(prefill) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 0ad9a19b3d..c342df2a8b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,8 +42,8 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.descriptors import cached +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -146,11 +146,10 @@ class EventsWorkerStore(SQLBaseStore): self._cleanup_old_transaction_ids, ) - self._get_event_cache = DeferredCache( - "*getEvent*", + self._get_event_cache = LruCache( + cache_name="*getEvent*", keylen=3, - max_entries=hs.config.caches.event_cache_size, - apply_cache_factor_from_config=False, + max_size=hs.config.caches.event_cache_size, ) self._event_fetch_lock = threading.Condition() @@ -749,7 +748,7 @@ class EventsWorkerStore(SQLBaseStore): event=original_ev, redacted_event=redacted_event ) - self._get_event_cache.prefill((event_id,), cache_entry) + self._get_event_cache.set((event_id,), cache_entry) result_map[event_id] = cache_entry return result_map diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 4e95dd9bf3..3b471d8fd3 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -337,6 +337,9 @@ class LruCache(Generic[KT, VT]): self.set = cache_set self.setdefault = cache_set_default self.pop = cache_pop + # `invalidate` is exposed for consistency with DeferredCache, so that it can be + # invalidated by the cache invalidation replication stream. + self.invalidate = cache_pop if cache_type is TreeCache: self.del_multi = cache_del_multi self.len = synchronized(cache_len) -- cgit 1.5.1 From 903d11c43a5df9f704e5dad4d14506a6470524fc Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 19 Oct 2020 15:00:12 +0100 Subject: Add `DeferredCache.get_immediate` method (#8568) * Add `DeferredCache.get_immediate` method A bunch of things that are currently calling `DeferredCache.get` are only really interested in the result if it's completed. We can optimise and simplify this case. 
* Remove unused 'default' parameter to DeferredCache.get() * another get_immediate instance --- changelog.d/8568.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/storage/databases/main/pusher.py | 2 +- synapse/storage/databases/main/receipts.py | 11 +-------- synapse/storage/databases/main/roommember.py | 2 +- synapse/util/caches/deferred_cache.py | 35 ++++++++++++++++++++-------- tests/util/caches/test_deferred_cache.py | 27 +++++++++++++++++---- 7 files changed, 53 insertions(+), 27 deletions(-) create mode 100644 changelog.d/8568.misc (limited to 'synapse/util/caches') diff --git a/changelog.d/8568.misc b/changelog.d/8568.misc new file mode 100644 index 0000000000..0ed7db92d3 --- /dev/null +++ b/changelog.d/8568.misc @@ -0,0 +1 @@ +Add `get_immediate` method to `DeferredCache`. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index c440f2545c..a701defcdd 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -496,6 +496,6 @@ class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))): # dedupe when we add callbacks to lru cache nodes, otherwise the number # of callbacks would grow. def __call__(self): - rules = self.cache.get(self.room_id, None, update_metrics=False) + rules = self.cache.get_immediate(self.room_id, None, update_metrics=False) if rules: rules.invalidate_all() diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index df8609b97b..7997242d90 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore): lock=False, ) - user_has_pusher = self.get_if_user_has_pusher.cache.get( + user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate( (user_id,), None, update_metrics=False ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 5cdf16521c..ca7917c989 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -25,7 +25,6 @@ from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -413,18 +412,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): if receipt_type != "m.read": return - # Returns either an ObservableDeferred or the raw result - res = self.get_users_with_read_receipts_in_room.cache.get( + res = self.get_users_with_read_receipts_in_room.cache.get_immediate( room_id, None, update_metrics=False ) - # first handle the ObservableDeferred case - if isinstance(res, ObservableDeferred): - if res.has_called(): - res = res.get_result() - else: - res = None - if res and user_id in res: # We'd only be adding to the set, so no point invalidating if the # user is already there diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 20fcdaa529..9b08b49862 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -531,7 +531,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # If we do then we can reuse that result and simply update it with # any membership changes 
in `delta_ids` if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get( + prev_res = self._get_joined_users_from_context.cache.get_immediate( (room_id, context.prev_group), None ) if prev_res and isinstance(prev_res, dict): diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 4026e1f8fa..faeef75506 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -17,7 +17,16 @@ import enum import threading -from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast +from typing import ( + Callable, + Generic, + Iterable, + MutableMapping, + Optional, + TypeVar, + Union, + cast, +) from prometheus_client import Gauge @@ -33,7 +42,7 @@ cache_pending_metric = Gauge( ["name"], ) - +T = TypeVar("T") KT = TypeVar("KT") VT = TypeVar("VT") @@ -119,21 +128,21 @@ class DeferredCache(Generic[KT, VT]): def get( self, key: KT, - default=_Sentinel.sentinel, callback: Optional[Callable[[], None]] = None, update_metrics: bool = True, - ): + ) -> Union[ObservableDeferred, VT]: """Looks the key up in the caches. Args: key(tuple) - default: What is returned if key is not in the caches. If not - specified then function throws KeyError instead callback(fn): Gets called when the entry in the cache is invalidated update_metrics (bool): whether to update the cache hit rate metrics Returns: Either an ObservableDeferred or the result itself + + Raises: + KeyError if the key is not found in the cache """ callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) @@ -145,13 +154,19 @@ class DeferredCache(Generic[KT, VT]): m.inc_hits() return val.deferred - val = self.cache.get( - key, default, callbacks=callbacks, update_metrics=update_metrics + val2 = self.cache.get( + key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics ) - if val is _Sentinel.sentinel: + if val2 is _Sentinel.sentinel: raise KeyError() else: - return val + return val2 + + def get_immediate( + self, key: KT, default: T, update_metrics: bool = True + ) -> Union[VT, T]: + """If we have a *completed* cached value, return it.""" + return self.cache.get(key, default, update_metrics=update_metrics) def set( self, diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 9717be56b6..8a08ab6661 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -38,6 +38,22 @@ class DeferredCacheTestCase(unittest.TestCase): self.assertEquals(cache.get("foo"), 123) + def test_get_immediate(self): + cache = DeferredCache("test") + d1 = defer.Deferred() + cache.set("key1", d1) + + # get_immediate should return default + v = cache.get_immediate("key1", 1) + self.assertEqual(v, 1) + + # now complete the set + d1.callback(2) + + # get_immediate should return result + v = cache.get_immediate("key1", 1) + self.assertEqual(v, 2) + def test_invalidate(self): cache = DeferredCache("test") cache.prefill(("foo",), 123) @@ -80,9 +96,11 @@ class DeferredCacheTestCase(unittest.TestCase): # now do the invalidation cache.invalidate_all() - # lookup should return none - self.assertIsNone(cache.get("key1", None)) - self.assertIsNone(cache.get("key2", None)) + # lookup should fail + with self.assertRaises(KeyError): + cache.get("key1") + with self.assertRaises(KeyError): + cache.get("key2") # both callbacks should have been callbacked self.assertTrue(callback_record[0], 
"Invalidation callback for key1 not called") @@ -90,7 +108,8 @@ class DeferredCacheTestCase(unittest.TestCase): # letting the other lookup complete should do nothing d1.callback("result1") - self.assertIsNone(cache.get("key1", None)) + with self.assertRaises(KeyError): + cache.get("key1", None) def test_eviction(self): cache = DeferredCache( -- cgit 1.5.1 From 96e7d3c4a0feec6d19b873fd550bcfffd485d910 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 19 Oct 2020 21:13:50 +0100 Subject: Fix 'LruCache' object has no attribute '_on_resize' (#8591) We need to make sure we are readu for the `set_cache_factor` callback. --- changelog.d/8591.misc | 1 + synapse/util/caches/lrucache.py | 10 +++++++++- tests/util/test_lrucache.py | 8 +++++++- 3 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8591.misc (limited to 'synapse/util/caches') diff --git a/changelog.d/8591.misc b/changelog.d/8591.misc new file mode 100644 index 0000000000..8f16bc3e7e --- /dev/null +++ b/changelog.d/8591.misc @@ -0,0 +1 @@ + Move metric registration code down into `LruCache`. diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 3b471d8fd3..60bb6ff642 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -124,6 +124,10 @@ class LruCache(Generic[KT, VT]): else: self.max_size = int(max_size) + # register_cache might call our "set_cache_factor" callback; there's nothing to + # do yet when we get resized. + self._on_resize = None # type: Optional[Callable[[],None]] + if cache_name is not None: metrics = register_cache( "lru_cache", @@ -332,7 +336,10 @@ class LruCache(Generic[KT, VT]): return key in cache self.sentinel = object() + + # make sure that we clear out any excess entries after we get resized. self._on_resize = evict + self.get = cache_get self.set = cache_set self.setdefault = cache_set_default @@ -383,6 +390,7 @@ class LruCache(Generic[KT, VT]): new_size = int(self._original_max_size * factor) if new_size != self.max_size: self.max_size = new_size - self._on_resize() + if self._on_resize: + self._on_resize() return True return False diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index f12834edab..a739a6aaaf 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -19,7 +19,8 @@ from mock import Mock from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache -from .. 
import unittest +from tests import unittest +from tests.unittest import override_config class LruCacheTestCase(unittest.HomeserverTestCase): @@ -83,6 +84,11 @@ class LruCacheTestCase(unittest.HomeserverTestCase): cache.clear() self.assertEquals(len(cache), 0) + @override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) + def test_special_size(self): + cache = LruCache(10, "mycache") + self.assertEqual(cache.max_size, 100) + class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self): -- cgit 1.5.1 From 1f4269700c2353263a605856f28ded28501368e1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 16 Oct 2020 12:34:55 +0100 Subject: Push some deferred wrangling down into DeferredCache --- changelog.d/8572.misc | 1 + synapse/util/caches/deferred_cache.py | 57 +++++++++++++++++++++++++++----- synapse/util/caches/descriptors.py | 32 ++++-------------- tests/util/caches/test_deferred_cache.py | 18 +++++----- tests/util/caches/test_descriptors.py | 5 ++- 5 files changed, 67 insertions(+), 46 deletions(-) create mode 100644 changelog.d/8572.misc (limited to 'synapse/util/caches') diff --git a/changelog.d/8572.misc b/changelog.d/8572.misc new file mode 100644 index 0000000000..ea2a6d340d --- /dev/null +++ b/changelog.d/8572.misc @@ -0,0 +1 @@ +Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s. diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index faeef75506..6c162e9f34 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -57,7 +57,7 @@ class DeferredCache(Generic[KT, VT]): """Wraps an LruCache, adding support for Deferred results. It expects that each entry added with set() will be a Deferred; likewise get() - may return an ObservableDeferred. + will return a Deferred. """ __slots__ = ( @@ -130,16 +130,22 @@ class DeferredCache(Generic[KT, VT]): key: KT, callback: Optional[Callable[[], None]] = None, update_metrics: bool = True, - ) -> Union[ObservableDeferred, VT]: + ) -> defer.Deferred: """Looks the key up in the caches. + For symmetry with set(), this method does *not* follow the synapse logcontext + rules: the logcontext will not be cleared on return, and the Deferred will run + its callbacks in the sentinel context. In other words: wrap the result with + make_deferred_yieldable() before `await`ing it. + Args: - key(tuple) - callback(fn): Gets called when the entry in the cache is invalidated + key: + callback: Gets called when the entry in the cache is invalidated update_metrics (bool): whether to update the cache hit rate metrics Returns: - Either an ObservableDeferred or the result itself + A Deferred which completes with the result. Note that this may later fail + if there is an ongoing set() operation which later completes with a failure. 
         Raises:
             KeyError if the key is not found in the cache
@@ -152,7 +158,7 @@ class DeferredCache(Generic[KT, VT]):
             m = self.cache.metrics
             assert m  # we always have a name, so should always have metrics
             m.inc_hits()
-            return val.deferred
+            return val.deferred.observe()
 
         val2 = self.cache.get(
             key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@@ -160,7 +166,7 @@ class DeferredCache(Generic[KT, VT]):
         if val2 is _Sentinel.sentinel:
             raise KeyError()
         else:
-            return val2
+            return defer.succeed(val2)
 
     def get_immediate(
         self, key: KT, default: T, update_metrics: bool = True
@@ -173,7 +179,36 @@ class DeferredCache(Generic[KT, VT]):
         key: KT,
         value: defer.Deferred,
         callback: Optional[Callable[[], None]] = None,
-    ) -> ObservableDeferred:
+    ) -> defer.Deferred:
+        """Adds a new entry to the cache (or updates an existing one).
+
+        The given `value` *must* be a Deferred.
+
+        First any existing entry for the same key is invalidated. Then a new entry
+        is added to the cache for the given key.
+
+        Until the `value` completes, calls to `get()` for the key will also result in an
+        incomplete Deferred, which will ultimately complete with the same result as
+        `value`.
+
+        If `value` completes successfully, subsequent calls to `get()` will then return
+        a completed deferred with the same result. If it *fails*, the cache is
+        invalidated and subsequent calls to `get()` will raise a KeyError.
+
+        If another call to `set()` happens before `value` completes, then (a) any
+        invalidation callbacks registered in the interim will be called, (b) any
+        `get()`s in the interim will continue to complete with the result from the
+        *original* `value`, (c) any future calls to `get()` will complete with the
+        result from the *new* `value`.
+
+        It is expected that `value` does *not* follow the synapse logcontext rules - ie,
+        if it is incomplete, it runs its callbacks in the sentinel context.
+
+        Args:
+            key: Key to be set
+            value: a deferred which will complete with a result to add to the cache
+            callback: An optional callback to be called when the entry is invalidated
+        """
         if not isinstance(value, defer.Deferred):
             raise TypeError("not a Deferred")
 
@@ -187,6 +222,8 @@ class DeferredCache(Generic[KT, VT]):
         if existing_entry:
             existing_entry.invalidate()
 
+        # XXX: why don't we invalidate the entry in `self.cache` yet?
+
         self._pending_deferred_cache[key] = entry
 
         def compare_and_pop():
@@ -230,7 +267,9 @@ class DeferredCache(Generic[KT, VT]):
         # _pending_deferred_cache to the real cache.
         #
         observer.addCallbacks(cb, eb)
-        return observable
+
+        # we return a new Deferred which will be called before any subsequent observers.
+ return observable.observe() def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): callbacks = [callback] if callback else [] diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1f43886804..a4172345ef 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -23,7 +23,6 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) @@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase): keylen=self.num_args, tree=self.tree, iterable=self.iterable, - ) # type: DeferredCache[Tuple, Any] + ) # type: DeferredCache[CacheKey, Any] def get_cache_key_gen(args, kwargs): """Given some args/kwargs return a generator that resolves into @@ -208,26 +207,12 @@ class CacheDescriptor(_CacheDescriptorBase): kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) try: - cached_result_d = cache.get(cache_key, callback=invalidate_callback) - - if isinstance(cached_result_d, ObservableDeferred): - observer = cached_result_d.observe() - else: - observer = defer.succeed(cached_result_d) - + ret = cache.get(cache_key, callback=invalidate_callback) except KeyError: ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) + ret = cache.set(cache_key, ret, callback=invalidate_callback) - def onErr(f): - cache.invalidate(cache_key) - return f - - ret.addErrback(onErr) - - result_d = cache.set(cache_key, ret, callback=invalidate_callback) - observer = result_d.observe() - - return make_deferred_yieldable(observer) + return make_deferred_yieldable(ret) wrapped = cast(_CachedFunction, _wrapped) @@ -286,7 +271,7 @@ class CacheListDescriptor(_CacheDescriptorBase): def __get__(self, obj, objtype=None): cached_method = getattr(obj, self.cached_method_name) - cache = cached_method.cache + cache = cached_method.cache # type: DeferredCache[CacheKey, Any] num_args = cached_method.num_args @functools.wraps(self.orig) @@ -326,14 +311,11 @@ class CacheListDescriptor(_CacheDescriptorBase): for arg in list_args: try: res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) - if not isinstance(res, ObservableDeferred): - results[arg] = res - elif not res.has_succeeded(): - res = res.observe() + if not res.called: res.addCallback(update_results_dict, arg) cached_defers.append(res) else: - results[arg] = res.get_result() + results[arg] = res.result except KeyError: missing.add(arg) diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 8a08ab6661..68d26128c1 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest from functools import partial from twisted.internet import defer from synapse.util.caches.deferred_cache import DeferredCache +from tests.unittest import TestCase -class DeferredCacheTestCase(unittest.TestCase): + +class DeferredCacheTestCase(TestCase): def test_empty(self): cache = DeferredCache("test") failed = False @@ -36,7 +37,7 @@ class DeferredCacheTestCase(unittest.TestCase): cache = DeferredCache("test") cache.prefill("foo", 123) - self.assertEquals(cache.get("foo"), 123) + self.assertEquals(self.successResultOf(cache.get("foo")), 123) def test_get_immediate(self): cache = DeferredCache("test") @@ -82,16 +83,15 @@ class DeferredCacheTestCase(unittest.TestCase): d2 = defer.Deferred() cache.set("key2", d2, partial(record_callback, 1)) - # lookup should return observable deferreds - self.assertFalse(cache.get("key1").has_called()) - self.assertFalse(cache.get("key2").has_called()) + # lookup should return pending deferreds + self.assertFalse(cache.get("key1").called) + self.assertFalse(cache.get("key2").called) # let one of the lookups complete d2.callback("result2") - # for now at least, the cache will return real results rather than an - # observabledeferred - self.assertEqual(cache.get("key2"), "result2") + # now the cache will return a completed deferred + self.assertEqual(self.successResultOf(cache.get("key2")), "result2") # now do the invalidation cache.invalidate_all() diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 3d738afa7f..fc2663c02d 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -27,7 +27,6 @@ from synapse.logging.context import ( current_context, make_deferred_yieldable, ) -from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import descriptors from synapse.util.caches.descriptors import cached @@ -419,9 +418,9 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): a = A() - a.func.prefill(("foo",), ObservableDeferred(d)) + a.func.prefill(("foo",), 456) - self.assertEquals(a.func("foo").result, d.result) + self.assertEquals(a.func("foo").result, 456) self.assertEquals(callcount[0], 0) @defer.inlineCallbacks -- cgit 1.5.1 From 2b3af01791120de5d0b829395109b42870d9e465 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 16 Oct 2020 13:16:02 +0100 Subject: optimise DeferredCache.set --- changelog.d/8593.misc | 1 + synapse/util/caches/deferred_cache.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8593.misc (limited to 'synapse/util/caches') diff --git a/changelog.d/8593.misc b/changelog.d/8593.misc new file mode 100644 index 0000000000..d266ba19a4 --- /dev/null +++ b/changelog.d/8593.misc @@ -0,0 +1 @@ +Minor optimisations in caching code. diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 6c162e9f34..fc01026285 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -214,9 +214,6 @@ class DeferredCache(Generic[KT, VT]): callbacks = [callback] if callback else [] self.check_thread() - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry: @@ -224,6 +221,18 @@ class DeferredCache(Generic[KT, VT]): # XXX: why don't we invalidate the entry in `self.cache` yet? 
+        # we can save a whole load of effort if the deferred is ready.
+        if value.called:
+            self.cache.set(key, value.result, callbacks)
+            return value
+
+        # otherwise, we'll add an entry to the _pending_deferred_cache for now,
+        # and add callbacks to add it to the cache properly later.
+
+        observable = ObservableDeferred(value, consumeErrors=True)
+        observer = observable.observe()
+        entry = CacheEntry(deferred=observable, callbacks=callbacks)
+
         self._pending_deferred_cache[key] = entry
 
         def compare_and_pop():
-- cgit 1.5.1

From c13820bcee5f23119b65f9386b7c03bd4a33acbe Mon Sep 17 00:00:00 2001
From: Richard van der Hoff
Date: Wed, 21 Oct 2020 18:54:53 +0100
Subject: fix failure case

---
 synapse/util/caches/deferred_cache.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

(limited to 'synapse/util/caches')

diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index fc01026285..601305487c 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -31,6 +31,7 @@ from typing import (
 from prometheus_client import Gauge
 
 from twisted.internet import defer
+from twisted.python import failure
 
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.lrucache import LruCache
@@ -223,7 +224,9 @@ class DeferredCache(Generic[KT, VT]):
 
         # we can save a whole load of effort if the deferred is ready.
         if value.called:
-            self.cache.set(key, value.result, callbacks)
+            result = value.result
+            if not isinstance(result, failure.Failure):
+                self.cache.set(key, result, callbacks)
             return value
 
         # otherwise, we'll add an entry to the _pending_deferred_cache for now,
-- cgit 1.5.1

From b28aaeb3a567ce1bbfd41796362b8bb0813ed0e3 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Wed, 21 Oct 2020 22:57:45 +0100
Subject: Optimise CacheDescriptor (#8594)

don't bother constructing a CacheContext unless we need one.
---
 changelog.d/8594.misc              |  1 +
 synapse/util/caches/descriptors.py | 12 +++++++-----
 2 files changed, 8 insertions(+), 5 deletions(-)
 create mode 100644 changelog.d/8594.misc

(limited to 'synapse/util/caches')

diff --git a/changelog.d/8594.misc b/changelog.d/8594.misc
new file mode 100644
index 0000000000..d266ba19a4
--- /dev/null
+++ b/changelog.d/8594.misc
@@ -0,0 +1 @@
+Minor optimisations in caching code.
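The change below is safe because a wrapped function only touches its `cache_context`
inside the function body, and the body only runs on a cache miss; on a hit the cached
result is returned without the context ever being observed. A sketch of the calling
pattern (the store class and `fetch_rules` helper are invented for illustration;
`cache_context.invalidate` is the callback exposed by the lazily-built `_CacheContext`):

    from synapse.util.caches.descriptors import cached

    class ExampleStore:
        @cached(cache_context=True)
        async def get_push_rules(self, room_id, cache_context):
            # `cache_context` is only injected when this body runs, i.e. on a
            # cache miss, so on a hit no _CacheContext needs to be constructed.
            return await fetch_rules(room_id, on_invalidate=cache_context.invalidate)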
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index a4172345ef..5d7fffee66 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -201,14 +201,16 @@ class CacheDescriptor(_CacheDescriptorBase): cache_key = get_cache_key(args, kwargs) - # Add our own `cache_context` to argument list if the wrapped function - # has asked for one - if self.add_cache_context: - kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) - try: ret = cache.get(cache_key, callback=invalidate_callback) except KeyError: + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext.get_instance( + cache, cache_key + ) + ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) ret = cache.set(cache_key, ret, callback=invalidate_callback) -- cgit 1.5.1 From cbc82aa09faa59acc20865e8b5c36561acb9a570 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 30 Oct 2020 11:43:17 +0000 Subject: Implement and use an @lru_cache decorator (#8595) We don't always need the full power of a DeferredCache. --- changelog.d/8595.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 37 +++-- synapse/util/caches/descriptors.py | 235 ++++++++++++++++++++++++------- tests/util/caches/test_descriptors.py | 60 +++++++- 4 files changed, 272 insertions(+), 61 deletions(-) create mode 100644 changelog.d/8595.misc (limited to 'synapse/util/caches') diff --git a/changelog.d/8595.misc b/changelog.d/8595.misc new file mode 100644 index 0000000000..24fab65cda --- /dev/null +++ b/changelog.d/8595.misc @@ -0,0 +1 @@ +Implement and use an @lru_cache decorator. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d9b5478b53..82a72dc34f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -15,8 +15,8 @@ # limitations under the License. import logging -from collections import namedtuple +import attr from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership, RelationTypes @@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.util.async_helpers import Linearizer from synapse.util.caches import register_cache -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import lru_cache +from synapse.util.caches.lrucache import LruCache from .push_rule_evaluator import PushRuleEvaluatorForEvent @@ -120,7 +121,7 @@ class BulkPushRuleEvaluator: dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = await self._get_rules_for_room(room_id) + rules_for_room = self._get_rules_for_room(room_id) rules_by_user = await rules_for_room.get_rules(event, context) @@ -138,7 +139,7 @@ class BulkPushRuleEvaluator: return rules_by_user - @cached() + @lru_cache() def _get_rules_for_room(self, room_id): """Get the current RulesForRoom object for the given room id @@ -275,12 +276,14 @@ class RulesForRoom: the entire cache for the room. 
""" - def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics): + def __init__( + self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics + ): """ Args: hs (HomeServer) room_id (str) - rules_for_room_cache(Cache): The cache object that caches these + rules_for_room_cache: The cache object that caches these RoomsForUser objects. room_push_rule_cache_metrics (CacheMetric) """ @@ -489,13 +492,21 @@ class RulesForRoom: self.state_group = state_group -class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))): - # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, - # which namedtuple does for us (i.e. two _CacheContext are the same if - # their caches and keys match). This is important in particular to - # dedupe when we add callbacks to lru cache nodes, otherwise the number - # of callbacks would grow. +@attr.attrs(slots=True, frozen=True) +class _Invalidation: + # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules, + # which means that it it is stored on the bulk_get_push_rules cache entry. In order + # to ensure that we don't accumulate lots of redunant callbacks on the cache entry, + # we need to ensure that two _Invalidation objects are "equal" if they refer to the + # same `cache` and `room_id`. + # + # attrs provides suitable __hash__ and __eq__ methods, provided we remember to + # set `frozen=True`. + + cache = attr.ib(type=LruCache) + room_id = attr.ib(type=str) + def __call__(self): - rules = self.cache.get_immediate(self.room_id, None, update_metrics=False) + rules = self.cache.get(self.room_id, None, update_metrics=False) if rules: rules.invalidate_all() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5d7fffee66..a924140cdf 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,10 +13,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import functools import inspect import logging -from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Generic, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) from weakref import WeakValueDictionary from twisted.internet import defer @@ -24,6 +37,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]): class _CacheDescriptorBase: - def __init__(self, orig: _CachedFunction, num_args, cache_context=False): + def __init__(self, orig: Callable[..., Any], num_args, cache_context=False): self.orig = orig arg_spec = inspect.getfullargspec(orig) @@ -97,8 +111,107 @@ class _CacheDescriptorBase: self.add_cache_context = cache_context + self.cache_key_builder = get_cache_key_builder( + self.arg_names, self.arg_defaults + ) + + +class _LruCachedFunction(Generic[F]): + cache = None # type: LruCache[CacheKey, Any] + __call__ = None # type: F + + +def lru_cache( + max_entries: int = 1000, cache_context: bool = False, +) -> Callable[[F], _LruCachedFunction[F]]: + """A method decorator that applies a memoizing cache around the function. 
+
+    This is more-or-less a drop-in equivalent to functools.lru_cache, although note
+    that the signature is slightly different.
+
+    The main differences with functools.lru_cache are:
+        (a) the size of the cache can be controlled via the cache_factor mechanism
+        (b) the wrapped function can request a "cache_context" which provides a
+            callback mechanism to indicate that the result is no longer valid
+        (c) prometheus metrics are exposed automatically.
+
+    The function should take zero or more arguments, which are used as the key for the
+    cache. Single-argument functions use that argument as the cache key; otherwise the
+    arguments are built into a tuple.
+
+    Cached functions can be "chained" (i.e. a cached function can call other cached
+    functions and get appropriately invalidated when the caches they call are
+    invalidated) by adding a special "cache_context" argument to the function
+    and passing that as a kwarg to all caches called. For example:
+
+        @lru_cache(cache_context=True)
+        def foo(self, key, cache_context):
+            r1 = self.bar1(key, on_invalidate=cache_context.invalidate)
+            r2 = self.bar2(key, on_invalidate=cache_context.invalidate)
+            return r1 + r2
+
+    The wrapped function also has a 'cache' property which offers direct access to the
+    underlying LruCache.
+    """
+
+    def func(orig: F) -> _LruCachedFunction[F]:
+        desc = LruCacheDescriptor(
+            orig, max_entries=max_entries, cache_context=cache_context,
+        )
+        return cast(_LruCachedFunction[F], desc)
+
+    return func
+
+
+class LruCacheDescriptor(_CacheDescriptorBase):
+    """Helper for @lru_cache"""
+
+    class _Sentinel(enum.Enum):
+        sentinel = object()
+
+    def __init__(
+        self, orig, max_entries: int = 1000, cache_context: bool = False,
+    ):
+        super().__init__(orig, num_args=None, cache_context=cache_context)
+        self.max_entries = max_entries
+
+    def __get__(self, obj, owner):
+        cache = LruCache(
+            cache_name=self.orig.__name__, max_size=self.max_entries,
+        )  # type: LruCache[CacheKey, Any]
+
+        get_cache_key = self.cache_key_builder
+        sentinel = LruCacheDescriptor._Sentinel.sentinel
+
+        @functools.wraps(self.orig)
+        def _wrapped(*args, **kwargs):
+            invalidate_callback = kwargs.pop("on_invalidate", None)
+            callbacks = (invalidate_callback,) if invalidate_callback else ()
+
+            cache_key = get_cache_key(args, kwargs)

-class CacheDescriptor(_CacheDescriptorBase):
+            ret = cache.get(cache_key, default=sentinel, callbacks=callbacks)
+            if ret != sentinel:
+                return ret
+
+            # Add our own `cache_context` to argument list if the wrapped function
+            # has asked for one
+            if self.add_cache_context:
+                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
+
+            ret2 = self.orig(obj, *args, **kwargs)
+            cache.set(cache_key, ret2, callbacks=callbacks)
+
+            return ret2
+
+        wrapped = cast(_CachedFunction, _wrapped)
+        wrapped.cache = cache
+        obj.__dict__[self.orig.__name__] = wrapped
+
+        return wrapped
+
+
+class DeferredCacheDescriptor(_CacheDescriptorBase):
     """ A method decorator that applies a memoizing cache around the function.

     This caches deferreds, rather than the results themselves.
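[Note on the sentinel lookup in LruCacheDescriptor._wrapped above: using an enum member as the "missing" marker allows a single cache.get with a default instead of an "in" test plus a second lookup, and mypy can narrow the type after the miss check. A minimal sketch of the pattern with a plain dict; get_or_compute is a hypothetical helper:]

    import enum

    class _Sentinel(enum.Enum):
        # An enum member works well as a "missing" marker: it is a singleton,
        # compares unequal to every real value, and type checkers can narrow on it.
        sentinel = object()

    def get_or_compute(cache, key, compute):
        # One lookup with a default replaces an "in" test plus a second lookup.
        ret = cache.get(key, _Sentinel.sentinel)
        if ret is not _Sentinel.sentinel:
            return ret
        value = compute(key)
        cache[key] = value
        return value

    cache = {}
    print(get_or_compute(cache, 4, lambda k: k * k))  # computed: 16
    print(get_or_compute(cache, 4, lambda k: k * k))  # cached: 16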
Deferreds that @@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=False, iterable=False, ): - super().__init__(orig, num_args=num_args, cache_context=cache_context) self.max_entries = max_entries @@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase): iterable=self.iterable, ) # type: DeferredCache[CacheKey, Any] - def get_cache_key_gen(args, kwargs): - """Given some args/kwargs return a generator that resolves into - the cache_key. - - We loop through each arg name, looking up if its in the `kwargs`, - otherwise using the next argument in `args`. If there are no more - args then we try looking the arg name up in the defaults - """ - pos = 0 - for nm in self.arg_names: - if nm in kwargs: - yield kwargs[nm] - elif pos < len(args): - yield args[pos] - pos += 1 - else: - yield self.arg_defaults[nm] - - # By default our cache key is a tuple, but if there is only one item - # then don't bother wrapping in a tuple. This is to save memory. - if self.num_args == 1: - nm = self.arg_names[0] - - def get_cache_key(args, kwargs): - if nm in kwargs: - return kwargs[nm] - elif len(args): - return args[0] - else: - return self.arg_defaults[nm] - - else: - - def get_cache_key(args, kwargs): - return tuple(get_cache_key_gen(args, kwargs)) + get_cache_key = self.cache_key_builder @functools.wraps(self.orig) def _wrapped(*args, **kwargs): @@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase): wrapped.prefill = lambda key, val: cache.prefill(key[0], val) else: wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_many = cache.invalidate_many wrapped.prefill = cache.prefill @@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase): return wrapped -class CacheListDescriptor(_CacheDescriptorBase): +class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given a list of keys it looks in the cache to find any hits, then passes @@ -382,11 +459,13 @@ class _CacheContext: on a lower level. """ + Cache = Union[DeferredCache, LruCache] + _cache_context_objects = ( WeakValueDictionary() - ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext] + ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext] - def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None + def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None: self._cache = cache self._cache_key = cache_key @@ -396,8 +475,8 @@ class _CacheContext: @classmethod def get_instance( - cls, cache, cache_key - ): # type: (DeferredCache, CacheKey) -> _CacheContext + cls, cache: "_CacheContext.Cache", cache_key: CacheKey + ) -> "_CacheContext": """Returns an instance constructed with the given arguments. A new instance is only created if none already exists. @@ -418,7 +497,7 @@ def cached( cache_context: bool = False, iterable: bool = False, ) -> Callable[[F], _CachedFunction[F]]: - func = lambda orig: CacheDescriptor( + func = lambda orig: DeferredCacheDescriptor( orig, max_entries=max_entries, num_args=num_args, @@ -460,7 +539,7 @@ def cachedList( def batch_do_something(self, first_arg, second_args): ... 
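[Note on _CacheContext.get_instance above: interning instances in a WeakValueDictionary guarantees that repeated requests for the same (cache, key) pair return the same object -- so it deduplicates when used as a callback -- while unused instances remain garbage-collectable. A simplified, hypothetical sketch:]

    from weakref import WeakValueDictionary

    class CacheContext:
        _instances = WeakValueDictionary()  # (cache, key) -> CacheContext

        def __init__(self, cache, key):
            self.cache = cache
            self.key = key

        @classmethod
        def get_instance(cls, cache, key):
            # Return the interned instance for this pair, creating it if needed.
            # The dictionary holds only weak references to its values, so an
            # instance disappears once nothing else keeps it alive.
            instance = cls._instances.get((cache, key))
            if instance is None:
                instance = cls(cache, key)
                cls._instances[(cache, key)] = instance
            return instance

    c = object()  # stands in for a cache object
    assert CacheContext.get_instance(c, "k") is CacheContext.get_instance(c, "k")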
""" - func = lambda orig: CacheListDescriptor( + func = lambda orig: DeferredCacheListDescriptor( orig, cached_method_name=cached_method_name, list_name=list_name, @@ -468,3 +547,65 @@ def cachedList( ) return cast(Callable[[F], _CachedFunction[F]], func) + + +def get_cache_key_builder( + param_names: Sequence[str], param_defaults: Mapping[str, Any] +) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: + """Construct a function which will build cache keys suitable for a cached function + + Args: + param_names: list of formal parameter names for the cached function + param_defaults: a mapping from parameter name to default value for that param + + Returns: + A function which will take an (args, kwargs) pair and return a cache key + """ + + # By default our cache key is a tuple, but if there is only one item + # then don't bother wrapping in a tuple. This is to save memory. + + if len(param_names) == 1: + nm = param_names[0] + + def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: + if nm in kwargs: + return kwargs[nm] + elif len(args): + return args[0] + else: + return param_defaults[nm] + + else: + + def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: + return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs)) + + return get_cache_key + + +def _get_cache_key_gen( + param_names: Iterable[str], + param_defaults: Mapping[str, Any], + args: Sequence[Any], + kwargs: Mapping[str, Any], +) -> Iterable[Any]: + """Given some args/kwargs return a generator that resolves into + the cache_key. + + This is essentially the same operation as `inspect.getcallargs`, but optimised so + that we don't need to inspect the target function for each call. + """ + + # We loop through each arg name, looking up if its in the `kwargs`, + # otherwise using the next argument in `args`. If there are no more + # args then we try looking the arg name up in the defaults. 
+ pos = 0 + for nm in param_names: + if nm in kwargs: + yield kwargs[nm] + elif pos < len(args): + yield args[pos] + pos += 1 + else: + yield param_defaults[nm] diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 2ad08f541b..cf1e3203a4 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -29,13 +29,46 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, lru_cache from tests import unittest +from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) +class LruCacheDecoratorTestCase(unittest.TestCase): + def test_base(self): + class Cls: + def __init__(self): + self.mock = mock.Mock() + + @lru_cache() + def fn(self, arg1, arg2): + return self.mock(arg1, arg2) + + obj = Cls() + obj.mock.return_value = "fish" + r = obj.fn(1, 2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, 2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = obj.fn(1, 3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(1, 3) + obj.mock.reset_mock() + + # the two values should now be cached + r = obj.fn(1, 2) + self.assertEqual(r, "fish") + r = obj.fn(1, 3) + self.assertEqual(r, "chips") + obj.mock.assert_not_called() + + def run_on_reactor(): d = defer.Deferred() reactor.callLater(0, d.callback, 0) @@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase): d = obj.fn(1) self.failureResultOf(d, SynapseError) + def test_invalidate_cascade(self): + """Invalidations should cascade up through cache contexts""" + + class Cls: + @cached(cache_context=True) + async def func1(self, key, cache_context): + return await self.func2(key, on_invalidate=cache_context.invalidate) + + @cached(cache_context=True) + async def func2(self, key, cache_context): + return self.func3(key, on_invalidate=cache_context.invalidate) + + @lru_cache(cache_context=True) + def func3(self, key, cache_context): + self.invalidate = cache_context.invalidate + return 42 + + obj = Cls() + + top_invalidate = mock.Mock() + r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate)) + self.assertEqual(r, 42) + obj.invalidate() + top_invalidate.assert_called_once() + class CacheDecoratorTestCase(unittest.HomeserverTestCase): """More tests for @cached -- cgit 1.5.1
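[Closing note on get_cache_key_builder, added in the last descriptors.py patch above: single-parameter functions use the bare argument as the cache key, while multi-parameter functions get a tuple with defaults filled in. A usage sketch; the parameter names and values are illustrative only:]

    from synapse.util.caches.descriptors import get_cache_key_builder

    one = get_cache_key_builder(["user_id"], {})
    # A single parameter is used directly as the key -- no tuple wrapper.
    assert one(("@alice:example",), {}) == "@alice:example"

    two = get_cache_key_builder(["user_id", "device_id"], {"device_id": None})
    # Missing arguments fall back to the declared defaults.
    assert two(("@alice:example",), {}) == ("@alice:example", None)
    assert two(("@alice:example", "DEV"), {}) == ("@alice:example", "DEV")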