summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/deferred_cache.py5
-rw-r--r--synapse/util/caches/dictionary_cache.py22
-rw-r--r--synapse/util/caches/lrucache.py78
3 files changed, 82 insertions, 23 deletions
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 91fdc8142d..4026e1f8fa 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -98,7 +98,7 @@ class DeferredCache(Generic[KT, VT]):
             size_callback=(lambda d: len(d)) if iterable else None,
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
-        )
+        )  # type: LruCache[KT, VT]
 
         self.thread = None  # type: Optional[threading.Thread]
 
@@ -240,11 +240,12 @@ class DeferredCache(Generic[KT, VT]):
         self.check_thread()
         if not isinstance(key, tuple):
             raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+        key = cast(KT, key)
         self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
         # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
+        entry_dict = self._pending_deferred_cache.pop(key, None)
         if entry_dict is not None:
             for entry in iterate_tree_cache_entry(entry_dict):
                 entry.invalidate()
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 8b426c005b..588d2d49f2 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -12,10 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 import logging
 import threading
 from collections import namedtuple
+from typing import Any
 
 from synapse.util.caches.lrucache import LruCache
 
@@ -38,23 +39,26 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
         return len(self.value)
 
 
+class _Sentinel(enum.Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
+
+
 class DictionaryCache:
     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
     fetching a subset of dictionary keys for a particular key.
     """
 
     def __init__(self, name, max_entries=1000):
-        self.cache = LruCache(max_size=max_entries, cache_name=name, size_callback=len)
+        self.cache = LruCache(
+            max_size=max_entries, cache_name=name, size_callback=len
+        )  # type: LruCache[Any, DictionaryEntry]
 
         self.name = name
         self.sequence = 0
         self.thread = None
 
-        class Sentinel:
-            __slots__ = []
-
-        self.sentinel = Sentinel()
-
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
@@ -76,8 +80,8 @@ class DictionaryCache:
         Returns:
             DictionaryEntry
         """
-        entry = self.cache.get(key, self.sentinel)
-        if entry is not self.sentinel:
+        entry = self.cache.get(key, _Sentinel.sentinel)
+        if entry is not _Sentinel.sentinel:
             if dict_keys is None:
                 return DictionaryEntry(
                     entry.full, entry.known_absent, dict(entry.value)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index e4804f79e0..4e95dd9bf3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,12 +15,35 @@
 
 import threading
 from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
+
+from typing_extensions import Literal
 
 from synapse.config import cache as cache_config
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.treecache import TreeCache
 
+# Function type: the type used for invalidation callbacks
+FT = TypeVar("FT", bound=Callable[..., Any])
+
+# Key and Value type for the cache
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+# a general type var, distinct from either KT or VT
+T = TypeVar("T")
+
 
 def enumerate_leaves(node, depth):
     if depth == 0:
@@ -42,7 +65,7 @@ class _Node:
         self.callbacks = callbacks
 
 
-class LruCache:
+class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
 
@@ -128,13 +151,13 @@ class LruCache:
                 if metrics:
                     metrics.inc_evictions(evicted_len)
 
-        def synchronized(f):
+        def synchronized(f: FT) -> FT:
             @wraps(f)
             def inner(*args, **kwargs):
                 with lock:
                     return f(*args, **kwargs)
 
-            return inner
+            return cast(FT, inner)
 
         cached_cache_len = [0]
         if size_callback is not None:
@@ -188,8 +211,31 @@ class LruCache:
             node.callbacks.clear()
             return deleted_len
 
+        @overload
+        def cache_get(
+            key: KT,
+            default: Literal[None] = None,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_get(
+            key: KT,
+            default: T,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_get(key, default=None, callbacks=[], update_metrics=True):
+        def cache_get(
+            key: KT,
+            default: Optional[T] = None,
+            callbacks: Iterable[Callable[[], None]] = [],
+            update_metrics: bool = True,
+        ):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
@@ -203,7 +249,7 @@ class LruCache:
                 return default
 
         @synchronized
-        def cache_set(key, value, callbacks=[]):
+        def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
             node = cache.get(key, None)
             if node is not None:
                 # We sometimes store large objects, e.g. dicts, which cause
@@ -232,7 +278,7 @@ class LruCache:
             evict()
 
         @synchronized
-        def cache_set_default(key, value):
+        def cache_set_default(key: KT, value: VT) -> VT:
             node = cache.get(key, None)
             if node is not None:
                 return node.value
@@ -241,8 +287,16 @@ class LruCache:
                 evict()
                 return value
 
+        @overload
+        def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_pop(key: KT, default: T) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_pop(key, default=None):
+        def cache_pop(key: KT, default: Optional[T] = None):
             node = cache.get(key, None)
             if node:
                 delete_node(node)
@@ -252,18 +306,18 @@ class LruCache:
                 return default
 
         @synchronized
-        def cache_del_multi(key):
+        def cache_del_multi(key: KT) -> None:
             """
             This will only work if constructed with cache_type=TreeCache
             """
             popped = cache.pop(key)
             if popped is None:
                 return
-            for leaf in enumerate_leaves(popped, keylen - len(key)):
+            for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
                 delete_node(leaf)
 
         @synchronized
-        def cache_clear():
+        def cache_clear() -> None:
             list_root.next_node = list_root
             list_root.prev_node = list_root
             for node in cache.values():
@@ -274,7 +328,7 @@ class LruCache:
                 cached_cache_len[0] = 0
 
         @synchronized
-        def cache_contains(key):
+        def cache_contains(key: KT) -> bool:
             return key in cache
 
         self.sentinel = object()