summary refs log tree commit diff
path: root/synapse/util/caches/lrucache.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/lrucache.py')
-rw-r--r--synapse/util/caches/lrucache.py78
1 files changed, 66 insertions, 12 deletions
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index e4804f79e0..4e95dd9bf3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,12 +15,35 @@
 
 import threading
 from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
+
+from typing_extensions import Literal
 
 from synapse.config import cache as cache_config
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.treecache import TreeCache
 
+# Function type: the type used for invalidation callbacks
+FT = TypeVar("FT", bound=Callable[..., Any])
+
+# Key and Value type for the cache
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+# a general type var, distinct from either KT or VT
+T = TypeVar("T")
+
 
 def enumerate_leaves(node, depth):
     if depth == 0:
@@ -42,7 +65,7 @@ class _Node:
         self.callbacks = callbacks
 
 
-class LruCache:
+class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
 
@@ -128,13 +151,13 @@ class LruCache:
                 if metrics:
                     metrics.inc_evictions(evicted_len)
 
-        def synchronized(f):
+        def synchronized(f: FT) -> FT:
             @wraps(f)
             def inner(*args, **kwargs):
                 with lock:
                     return f(*args, **kwargs)
 
-            return inner
+            return cast(FT, inner)
 
         cached_cache_len = [0]
         if size_callback is not None:
@@ -188,8 +211,31 @@ class LruCache:
             node.callbacks.clear()
             return deleted_len
 
+        @overload
+        def cache_get(
+            key: KT,
+            default: Literal[None] = None,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_get(
+            key: KT,
+            default: T,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_get(key, default=None, callbacks=[], update_metrics=True):
+        def cache_get(
+            key: KT,
+            default: Optional[T] = None,
+            callbacks: Iterable[Callable[[], None]] = [],
+            update_metrics: bool = True,
+        ):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
@@ -203,7 +249,7 @@ class LruCache:
                 return default
 
         @synchronized
-        def cache_set(key, value, callbacks=[]):
+        def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
             node = cache.get(key, None)
             if node is not None:
                 # We sometimes store large objects, e.g. dicts, which cause
@@ -232,7 +278,7 @@ class LruCache:
             evict()
 
         @synchronized
-        def cache_set_default(key, value):
+        def cache_set_default(key: KT, value: VT) -> VT:
             node = cache.get(key, None)
             if node is not None:
                 return node.value
@@ -241,8 +287,16 @@ class LruCache:
                 evict()
                 return value
 
+        @overload
+        def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_pop(key: KT, default: T) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_pop(key, default=None):
+        def cache_pop(key: KT, default: Optional[T] = None):
             node = cache.get(key, None)
             if node:
                 delete_node(node)
@@ -252,18 +306,18 @@ class LruCache:
                 return default
 
         @synchronized
-        def cache_del_multi(key):
+        def cache_del_multi(key: KT) -> None:
             """
             This will only work if constructed with cache_type=TreeCache
             """
             popped = cache.pop(key)
             if popped is None:
                 return
-            for leaf in enumerate_leaves(popped, keylen - len(key)):
+            for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
                 delete_node(leaf)
 
         @synchronized
-        def cache_clear():
+        def cache_clear() -> None:
             list_root.next_node = list_root
             list_root.prev_node = list_root
             for node in cache.values():
@@ -274,7 +328,7 @@ class LruCache:
                 cached_cache_len[0] = 0
 
         @synchronized
-        def cache_contains(key):
+        def cache_contains(key: KT) -> bool:
             return key in cache
 
         self.sentinel = object()