Diffstat (limited to 'synapse/util/caches/lrucache.py')
-rw-r--r--  synapse/util/caches/lrucache.py  42
1 file changed, 17 insertions(+), 25 deletions(-)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index eb96f7e665..a0a7a9de32 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,15 +15,14 @@
 import logging
 import threading
 import weakref
-from enum import Enum
 from functools import wraps
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Collection,
-    Dict,
     Generic,
+    Iterable,
     List,
     Optional,
     Type,
@@ -191,7 +190,7 @@ class _Node(Generic[KT, VT]):
         root: "ListNode[_Node]",
         key: KT,
         value: VT,
-        cache: "weakref.ReferenceType[LruCache[KT, VT]]",
+        cache: "weakref.ReferenceType[LruCache]",
         clock: Clock,
         callbacks: Collection[Callable[[], None]] = (),
         prune_unread_entries: bool = True,
@@ -271,10 +270,7 @@ class _Node(Generic[KT, VT]):
         removed from all lists.
         """
         cache = self._cache()
-        if (
-            cache is None
-            or cache.pop(self.key, _Sentinel.sentinel) is _Sentinel.sentinel
-        ):
+        if cache is None or cache.pop(self.key, cache.sentinel) is cache.sentinel:
             # `cache.pop` should call `drop_from_lists()`, unless this Node had
             # already been removed from the cache.
             self.drop_from_lists()
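The condition above deliberately compares against a per-instance sentinel rather than using truthiness: a check like `not cache.pop(self.key, None)` would misclassify falsy cached values (0, "", False) as missing, and `not cache` would be true for an empty but still-live cache. A minimal standalone sketch of the sentinel-default pattern (illustrative names, not Synapse code):

```python
# Minimal standalone sketch (not Synapse code) of the sentinel-default
# pattern: a fresh object() is identical only to itself, so it can never
# collide with a genuinely stored value -- including falsy ones.
_MISSING = object()

def pop_if_present(d: dict, key: object) -> bool:
    """Pop `key` from `d`, reporting whether it was actually present."""
    value = d.pop(key, _MISSING)
    # `is` compares identity, so stored None/0/"" still count as present.
    return value is not _MISSING

d = {"a": 0, "b": None}
assert pop_if_present(d, "a")      # falsy value, but the key existed
assert pop_if_present(d, "b")      # None value, but the key existed
assert not pop_if_present(d, "a")  # already removed
```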
@@ -294,12 +290,6 @@ class _Node(Generic[KT, VT]):
             self._global_list_node.update_last_access(clock)
 
 
-class _Sentinel(Enum):
-    # defining a sentinel in this way allows mypy to correctly handle the
-    # type of a dictionary lookup.
-    sentinel = object()
-
-
 class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
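For context, the removed `_Sentinel` enum existed purely for the type checker: making the sentinel a one-member Enum gives it a distinct type, so mypy can narrow a lookup result after an identity check. A hedged sketch of that trick in isolation (`lookup` is a hypothetical helper, not part of Synapse):

```python
# Sketch of the Enum-sentinel trick the removed class relied on: a
# one-member Enum gives the sentinel a distinct type, letting mypy
# narrow the union after an `is` check.
from enum import Enum
from typing import Dict, Union

class _Sentinel(Enum):
    sentinel = object()

def lookup(d: Dict[str, int], key: str) -> int:
    result: Union[int, _Sentinel] = d.get(key, _Sentinel.sentinel)
    if result is _Sentinel.sentinel:
        raise KeyError(key)
    # mypy narrows `result` to int here; a bare object() sentinel would
    # have needed an explicit cast() instead.
    return result
```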
@@ -312,7 +302,7 @@ class LruCache(Generic[KT, VT]):
         max_size: int,
         cache_name: Optional[str] = None,
         cache_type: Type[Union[dict, TreeCache]] = dict,
-        size_callback: Optional[Callable[[VT], int]] = None,
+        size_callback: Optional[Callable] = None,
         metrics_collection_callback: Optional[Callable[[], None]] = None,
         apply_cache_factor_from_config: bool = True,
         clock: Optional[Clock] = None,
@@ -349,7 +339,7 @@ class LruCache(Generic[KT, VT]):
         else:
             real_clock = clock
 
-        cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
+        cache = cache_type()
         self.cache = cache  # Used for introspection.
         self.apply_cache_factor_from_config = apply_cache_factor_from_config
 
@@ -384,7 +374,7 @@ class LruCache(Generic[KT, VT]):
         # creating more each time we create a `_Node`.
         weak_ref_to_self = weakref.ref(self)
 
-        list_root = ListNode[_Node[KT, VT]].create_root_node()
+        list_root = ListNode[_Node].create_root_node()
 
         lock = threading.Lock()
 
@@ -432,7 +422,7 @@ class LruCache(Generic[KT, VT]):
         def add_node(
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
         ) -> None:
-            node: _Node[KT, VT] = _Node(
+            node = _Node(
                 list_root,
                 key,
                 value,
@@ -449,10 +439,10 @@ class LruCache(Generic[KT, VT]):
             if caches.TRACK_MEMORY_USAGE and metrics:
                 metrics.inc_memory_usage(node.memory)
 
-        def move_node_to_front(node: _Node[KT, VT]) -> None:
+        def move_node_to_front(node: _Node) -> None:
             node.move_to_front(real_clock, list_root)
 
-        def delete_node(node: _Node[KT, VT]) -> int:
+        def delete_node(node: _Node) -> int:
             node.drop_from_lists()
 
             deleted_len = 1
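`move_node_to_front` and `delete_node` are thin wrappers over an intrusive doubly linked list, which is what makes LRU promotion O(1). A toy sketch of the underlying move-to-front splice (`DLNode` is illustrative, not Synapse's `ListNode` API):

```python
# Toy sketch of the move-to-front splice behind O(1) LRU promotion.
class DLNode:
    __slots__ = ("prev", "next")

    def __init__(self) -> None:
        # A fresh node points at itself, like the list's root.
        self.prev = self.next = self

def move_to_front(root: DLNode, node: DLNode) -> None:
    # Unlink the node from its current neighbours.
    node.prev.next = node.next
    node.next.prev = node.prev
    # Splice it back in directly after the root (the recently-used end).
    node.next = root.next
    node.prev = root
    root.next.prev = node
    root.next = node
```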
@@ -506,7 +496,7 @@ class LruCache(Generic[KT, VT]):
 
         @synchronized
         def cache_set(
-            key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
+            key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
         ) -> None:
             node = cache.get(key, None)
             if node is not None:
@@ -600,6 +590,8 @@ class LruCache(Generic[KT, VT]):
         def cache_contains(key: KT) -> bool:
             return key in cache
 
+        self.sentinel = object()
+
         # make sure that we clear out any excess entries after we get resized.
         self._on_resize = evict
 
@@ -616,18 +608,18 @@ class LruCache(Generic[KT, VT]):
         self.clear = cache_clear
 
     def __getitem__(self, key: KT) -> VT:
-        result = self.get(key, _Sentinel.sentinel)
-        if result is _Sentinel.sentinel:
+        result = self.get(key, self.sentinel)
+        if result is self.sentinel:
             raise KeyError()
         else:
-            return result
+            return cast(VT, result)
 
     def __setitem__(self, key: KT, value: VT) -> None:
         self.set(key, value)
 
     def __delitem__(self, key: KT, value: VT) -> None:
-        result = self.pop(key, _Sentinel.sentinel)
-        if result is _Sentinel.sentinel:
+        result = self.pop(key, self.sentinel)
+        if result is self.sentinel:
             raise KeyError()
 
     def __len__(self) -> int:
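Putting the bound methods together, a hedged usage sketch of the public surface shown in this diff (`set`/`get`/`pop`/`clear` plus the dict-style dunders); it assumes a default reactor/clock is available and disables the config cache factor so `max_size` is taken literally:

```python
# Hedged usage sketch of the surface bound in __init__ above; keys and
# values are arbitrary examples.
from synapse.util.caches.lrucache import LruCache

cache: LruCache = LruCache(max_size=2, apply_cache_factor_from_config=False)
cache["a"] = 1
cache.set("b", 2, callbacks=[lambda: print("b invalidated")])

assert cache["a"] == 1          # __getitem__ raises KeyError when absent
cache["c"] = 3                  # over max_size: evicts the LRU entry, "b"
assert cache.get("b") is None   # get() falls back to its default

cache.clear()
assert len(cache) == 0
```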