diff options
Diffstat (limited to 'synapse/util')
-rw-r--r-- | synapse/util/caches/deferred_cache.py | 5 | ||||
-rw-r--r-- | synapse/util/caches/dictionary_cache.py | 22 | ||||
-rw-r--r-- | synapse/util/caches/lrucache.py | 78 |
3 files changed, 82 insertions, 23 deletions
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 91fdc8142d..4026e1f8fa 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -98,7 +98,7 @@ class DeferredCache(Generic[KT, VT]): size_callback=(lambda d: len(d)) if iterable else None, metrics_collection_callback=metrics_cb, apply_cache_factor_from_config=apply_cache_factor_from_config, - ) + ) # type: LruCache[KT, VT] self.thread = None # type: Optional[threading.Thread] @@ -240,11 +240,12 @@ class DeferredCache(Generic[KT, VT]): self.check_thread() if not isinstance(key, tuple): raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + key = cast(KT, key) self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) + entry_dict = self._pending_deferred_cache.pop(key, None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): entry.invalidate() diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 8b426c005b..588d2d49f2 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import logging import threading from collections import namedtuple +from typing import Any from synapse.util.caches.lrucache import LruCache @@ -38,23 +39,26 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va return len(self.value) +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() + + class DictionaryCache: """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. fetching a subset of dictionary keys for a particular key. """ def __init__(self, name, max_entries=1000): - self.cache = LruCache(max_size=max_entries, cache_name=name, size_callback=len) + self.cache = LruCache( + max_size=max_entries, cache_name=name, size_callback=len + ) # type: LruCache[Any, DictionaryEntry] self.name = name self.sequence = 0 self.thread = None - class Sentinel: - __slots__ = [] - - self.sentinel = Sentinel() - def check_thread(self): expected_thread = self.thread if expected_thread is None: @@ -76,8 +80,8 @@ class DictionaryCache: Returns: DictionaryEntry """ - entry = self.cache.get(key, self.sentinel) - if entry is not self.sentinel: + entry = self.cache.get(key, _Sentinel.sentinel) + if entry is not _Sentinel.sentinel: if dict_keys is None: return DictionaryEntry( entry.full, entry.known_absent, dict(entry.value) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index e4804f79e0..4e95dd9bf3 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -15,12 +15,35 @@ import threading from functools import wraps -from typing import Callable, Optional, Type, Union +from typing import ( + Any, + Callable, + Generic, + Iterable, + Optional, + Type, + TypeVar, + Union, + cast, + overload, +) + +from typing_extensions import Literal from synapse.config import cache as cache_config from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.treecache import TreeCache +# Function type: the type used for invalidation callbacks +FT = TypeVar("FT", bound=Callable[..., Any]) + +# Key and Value type for the cache +KT = TypeVar("KT") +VT = TypeVar("VT") + +# a general type var, distinct from either KT or VT +T = TypeVar("T") + def enumerate_leaves(node, depth): if depth == 0: @@ -42,7 +65,7 @@ class _Node: self.callbacks = callbacks -class LruCache: +class LruCache(Generic[KT, VT]): """ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. @@ -128,13 +151,13 @@ class LruCache: if metrics: metrics.inc_evictions(evicted_len) - def synchronized(f): + def synchronized(f: FT) -> FT: @wraps(f) def inner(*args, **kwargs): with lock: return f(*args, **kwargs) - return inner + return cast(FT, inner) cached_cache_len = [0] if size_callback is not None: @@ -188,8 +211,31 @@ class LruCache: node.callbacks.clear() return deleted_len + @overload + def cache_get( + key: KT, + default: Literal[None] = None, + callbacks: Iterable[Callable[[], None]] = ..., + update_metrics: bool = ..., + ) -> Optional[VT]: + ... + + @overload + def cache_get( + key: KT, + default: T, + callbacks: Iterable[Callable[[], None]] = ..., + update_metrics: bool = ..., + ) -> Union[T, VT]: + ... + @synchronized - def cache_get(key, default=None, callbacks=[], update_metrics=True): + def cache_get( + key: KT, + default: Optional[T] = None, + callbacks: Iterable[Callable[[], None]] = [], + update_metrics: bool = True, + ): node = cache.get(key, None) if node is not None: move_node_to_front(node) @@ -203,7 +249,7 @@ class LruCache: return default @synchronized - def cache_set(key, value, callbacks=[]): + def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []): node = cache.get(key, None) if node is not None: # We sometimes store large objects, e.g. dicts, which cause @@ -232,7 +278,7 @@ class LruCache: evict() @synchronized - def cache_set_default(key, value): + def cache_set_default(key: KT, value: VT) -> VT: node = cache.get(key, None) if node is not None: return node.value @@ -241,8 +287,16 @@ class LruCache: evict() return value + @overload + def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]: + ... + + @overload + def cache_pop(key: KT, default: T) -> Union[T, VT]: + ... + @synchronized - def cache_pop(key, default=None): + def cache_pop(key: KT, default: Optional[T] = None): node = cache.get(key, None) if node: delete_node(node) @@ -252,18 +306,18 @@ class LruCache: return default @synchronized - def cache_del_multi(key): + def cache_del_multi(key: KT) -> None: """ This will only work if constructed with cache_type=TreeCache """ popped = cache.pop(key) if popped is None: return - for leaf in enumerate_leaves(popped, keylen - len(key)): + for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))): delete_node(leaf) @synchronized - def cache_clear(): + def cache_clear() -> None: list_root.next_node = list_root list_root.prev_node = list_root for node in cache.values(): @@ -274,7 +328,7 @@ class LruCache: cached_cache_len[0] = 0 @synchronized - def cache_contains(key): + def cache_contains(key: KT) -> bool: return key in cache self.sentinel = object() |