diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index a0a7a9de32..eb96f7e665 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,14 +15,15 @@
import logging
import threading
import weakref
+from enum import Enum
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
+ Dict,
Generic,
- Iterable,
List,
Optional,
Type,
@@ -190,7 +191,7 @@ class _Node(Generic[KT, VT]):
root: "ListNode[_Node]",
key: KT,
value: VT,
- cache: "weakref.ReferenceType[LruCache]",
+ cache: "weakref.ReferenceType[LruCache[KT, VT]]",
clock: Clock,
callbacks: Collection[Callable[[], None]] = (),
prune_unread_entries: bool = True,
@@ -270,7 +271,10 @@ class _Node(Generic[KT, VT]):
removed from all lists.
"""
cache = self._cache()
- if not cache or not cache.pop(self.key, None):
+ if (
+ cache is None
+ or cache.pop(self.key, _Sentinel.sentinel) is _Sentinel.sentinel
+ ):
# `cache.pop` should call `drop_from_lists()`, unless this Node had
# already been removed from the cache.
self.drop_from_lists()
@@ -290,6 +294,12 @@ class _Node(Generic[KT, VT]):
self._global_list_node.update_last_access(clock)
+class _Sentinel(Enum):
+ # defining a sentinel in this way allows mypy to correctly handle the
+ # type of a dictionary lookup.
+ sentinel = object()
+
+
class LruCache(Generic[KT, VT]):
"""
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
@@ -302,7 +312,7 @@ class LruCache(Generic[KT, VT]):
max_size: int,
cache_name: Optional[str] = None,
cache_type: Type[Union[dict, TreeCache]] = dict,
- size_callback: Optional[Callable] = None,
+ size_callback: Optional[Callable[[VT], int]] = None,
metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True,
clock: Optional[Clock] = None,
@@ -339,7 +349,7 @@ class LruCache(Generic[KT, VT]):
else:
real_clock = clock
- cache = cache_type()
+ cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
self.cache = cache # Used for introspection.
self.apply_cache_factor_from_config = apply_cache_factor_from_config
@@ -374,7 +384,7 @@ class LruCache(Generic[KT, VT]):
# creating more each time we create a `_Node`.
weak_ref_to_self = weakref.ref(self)
- list_root = ListNode[_Node].create_root_node()
+ list_root = ListNode[_Node[KT, VT]].create_root_node()
lock = threading.Lock()
@@ -422,7 +432,7 @@ class LruCache(Generic[KT, VT]):
def add_node(
key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
) -> None:
- node = _Node(
+ node: _Node[KT, VT] = _Node(
list_root,
key,
value,
@@ -439,10 +449,10 @@ class LruCache(Generic[KT, VT]):
if caches.TRACK_MEMORY_USAGE and metrics:
metrics.inc_memory_usage(node.memory)
- def move_node_to_front(node: _Node) -> None:
+ def move_node_to_front(node: _Node[KT, VT]) -> None:
node.move_to_front(real_clock, list_root)
- def delete_node(node: _Node) -> int:
+ def delete_node(node: _Node[KT, VT]) -> int:
node.drop_from_lists()
deleted_len = 1
@@ -496,7 +506,7 @@ class LruCache(Generic[KT, VT]):
@synchronized
def cache_set(
- key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = ()
+ key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
) -> None:
node = cache.get(key, None)
if node is not None:
@@ -590,8 +600,6 @@ class LruCache(Generic[KT, VT]):
def cache_contains(key: KT) -> bool:
return key in cache
- self.sentinel = object()
-
# make sure that we clear out any excess entries after we get resized.
self._on_resize = evict
@@ -608,18 +616,18 @@ class LruCache(Generic[KT, VT]):
self.clear = cache_clear
def __getitem__(self, key: KT) -> VT:
- result = self.get(key, self.sentinel)
- if result is self.sentinel:
+ result = self.get(key, _Sentinel.sentinel)
+ if result is _Sentinel.sentinel:
raise KeyError()
else:
- return cast(VT, result)
+ return result
def __setitem__(self, key: KT, value: VT) -> None:
self.set(key, value)
def __delitem__(self, key: KT, value: VT) -> None:
- result = self.pop(key, self.sentinel)
- if result is self.sentinel:
+ result = self.pop(key, _Sentinel.sentinel)
+ if result is _Sentinel.sentinel:
raise KeyError()
def __len__(self) -> int:
|