diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index e4804f79e0..0eed53d3f4 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,12 +15,30 @@
import threading
from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
+
+from typing_extensions import Literal
from synapse.config import cache as cache_config
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache
+T = TypeVar("T")
+FT = TypeVar("FT", bound=Callable[..., Any])
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
def enumerate_leaves(node, depth):
if depth == 0:
@@ -42,7 +60,7 @@ class _Node:
self.callbacks = callbacks
-class LruCache:
+class LruCache(Generic[KT, VT]):
"""
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
@@ -128,13 +146,13 @@ class LruCache:
if metrics:
metrics.inc_evictions(evicted_len)
- def synchronized(f):
+ def synchronized(f: FT) -> FT:
@wraps(f)
def inner(*args, **kwargs):
with lock:
return f(*args, **kwargs)
- return inner
+ return cast(FT, inner)
cached_cache_len = [0]
if size_callback is not None:
@@ -188,8 +206,31 @@ class LruCache:
node.callbacks.clear()
return deleted_len
+ @overload
+ def cache_get(
+ key: KT,
+ default: Literal[None] = None,
+ callbacks: Iterable[Callable[[], None]] = ...,
+ update_metrics: bool = ...,
+ ) -> Optional[VT]:
+ ...
+
+ @overload
+ def cache_get(
+ key: KT,
+ default: T,
+ callbacks: Iterable[Callable[[], None]] = ...,
+ update_metrics: bool = ...,
+ ) -> Union[T, VT]:
+ ...
+
@synchronized
- def cache_get(key, default=None, callbacks=[], update_metrics=True):
+ def cache_get(
+ key: KT,
+ default=None,
+ callbacks: Iterable[Callable[[], None]] = [],
+ update_metrics: bool = True,
+ ):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
@@ -203,7 +244,7 @@ class LruCache:
return default
@synchronized
- def cache_set(key, value, callbacks=[]):
+ def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
@@ -232,7 +273,7 @@ class LruCache:
evict()
@synchronized
- def cache_set_default(key, value):
+ def cache_set_default(key: KT, value: VT) -> VT:
node = cache.get(key, None)
if node is not None:
return node.value
@@ -241,8 +282,16 @@ class LruCache:
evict()
return value
+ @overload
+ def cache_pop(key: KT, default: Literal[None] = None) -> Union[None, VT]:
+ ...
+
+ @overload
+ def cache_pop(key: KT, default: T) -> Union[T, VT]:
+ ...
+
@synchronized
- def cache_pop(key, default=None):
+ def cache_pop(key: KT, default=None):
node = cache.get(key, None)
if node:
delete_node(node)
@@ -252,18 +301,18 @@ class LruCache:
return default
@synchronized
- def cache_del_multi(key):
+ def cache_del_multi(key: KT) -> None:
"""
This will only work if constructed with cache_type=TreeCache
"""
popped = cache.pop(key)
if popped is None:
return
- for leaf in enumerate_leaves(popped, keylen - len(key)):
+ for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
delete_node(leaf)
@synchronized
- def cache_clear():
+ def cache_clear() -> None:
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
@@ -274,7 +323,7 @@ class LruCache:
cached_cache_len[0] = 0
@synchronized
- def cache_contains(key):
+ def cache_contains(key: KT) -> bool:
return key in cache
self.sentinel = object()
|