diff --git a/changelog.d/8562.misc b/changelog.d/8562.misc
new file mode 100644
index 0000000000..ebdbddb500
--- /dev/null
+++ b/changelog.d/8562.misc
@@ -0,0 +1 @@
+Add type annotations for `LruCache`.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index eb6f418b13..bff87fabde 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -69,7 +69,9 @@ class Auth:
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
- self.token_cache = LruCache(10000, "token_cache")
+ self.token_cache = LruCache(
+ 10000, "token_cache"
+ ) # type: LruCache[str, Tuple[str, bool]]
self._auth_blocking = AuthBlocking(self.hs)
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 4c95b149c5..2ce9e444ab 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,7 +16,7 @@
import logging
import re
-from typing import Any, Dict, List, Optional, Pattern, Union
+from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from synapse.events import EventBase
from synapse.types import UserID
@@ -173,19 +173,21 @@ class PushRuleEvaluatorForEvent:
# Similar to _glob_matches, but do not treat display_name as a glob.
r = regex_cache.get((display_name, False, True), None)
if not r:
- r = re.escape(display_name)
- r = _re_word_boundary(r)
- r = re.compile(r, flags=re.IGNORECASE)
+ r1 = re.escape(display_name)
+ r1 = _re_word_boundary(r1)
+ r = re.compile(r1, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r
- return r.search(body)
+ return bool(r.search(body))
def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None)
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
-regex_cache = LruCache(50000, "regex_push_cache")
+regex_cache = LruCache(
+ 50000, "regex_push_cache"
+) # type: LruCache[Tuple[str, bool, bool], Pattern]
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
@@ -203,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
if not r:
r = _glob_to_re(glob, word_boundary)
regex_cache[(glob, True, word_boundary)] = r
- return r.search(value)
+ return bool(r.search(value))
except re.error:
logger.warning("Failed to parse glob to regex: %r", glob)
return False
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 91fdc8142d..4026e1f8fa 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -98,7 +98,7 @@ class DeferredCache(Generic[KT, VT]):
size_callback=(lambda d: len(d)) if iterable else None,
metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config,
- )
+ ) # type: LruCache[KT, VT]
self.thread = None # type: Optional[threading.Thread]
@@ -240,11 +240,12 @@ class DeferredCache(Generic[KT, VT]):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+ key = cast(KT, key)
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
- entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
+ entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 8b426c005b..588d2d49f2 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import enum
import logging
import threading
from collections import namedtuple
+from typing import Any
from synapse.util.caches.lrucache import LruCache
@@ -38,23 +39,26 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
return len(self.value)
+class _Sentinel(enum.Enum):
+ # defining a sentinel in this way allows mypy to correctly handle the
+ # type of a dictionary lookup.
+ sentinel = object()
+
+
class DictionaryCache:
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
def __init__(self, name, max_entries=1000):
- self.cache = LruCache(max_size=max_entries, cache_name=name, size_callback=len)
+ self.cache = LruCache(
+ max_size=max_entries, cache_name=name, size_callback=len
+ ) # type: LruCache[Any, DictionaryEntry]
self.name = name
self.sequence = 0
self.thread = None
- class Sentinel:
- __slots__ = []
-
- self.sentinel = Sentinel()
-
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@@ -76,8 +80,8 @@ class DictionaryCache:
Returns:
DictionaryEntry
"""
- entry = self.cache.get(key, self.sentinel)
- if entry is not self.sentinel:
+ entry = self.cache.get(key, _Sentinel.sentinel)
+ if entry is not _Sentinel.sentinel:
if dict_keys is None:
return DictionaryEntry(
entry.full, entry.known_absent, dict(entry.value)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index e4804f79e0..4e95dd9bf3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,12 +15,35 @@
import threading
from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Iterable,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
+
+from typing_extensions import Literal
from synapse.config import cache as cache_config
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache
+# Function type: the type used for invalidation callbacks
+FT = TypeVar("FT", bound=Callable[..., Any])
+
+# Key and Value type for the cache
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+# a general type var, distinct from either KT or VT
+T = TypeVar("T")
+
def enumerate_leaves(node, depth):
if depth == 0:
@@ -42,7 +65,7 @@ class _Node:
self.callbacks = callbacks
-class LruCache:
+class LruCache(Generic[KT, VT]):
"""
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
@@ -128,13 +151,13 @@ class LruCache:
if metrics:
metrics.inc_evictions(evicted_len)
- def synchronized(f):
+ def synchronized(f: FT) -> FT:
@wraps(f)
def inner(*args, **kwargs):
with lock:
return f(*args, **kwargs)
- return inner
+ return cast(FT, inner)
cached_cache_len = [0]
if size_callback is not None:
@@ -188,8 +211,31 @@ class LruCache:
node.callbacks.clear()
return deleted_len
+ @overload
+ def cache_get(
+ key: KT,
+ default: Literal[None] = None,
+ callbacks: Iterable[Callable[[], None]] = ...,
+ update_metrics: bool = ...,
+ ) -> Optional[VT]:
+ ...
+
+ @overload
+ def cache_get(
+ key: KT,
+ default: T,
+ callbacks: Iterable[Callable[[], None]] = ...,
+ update_metrics: bool = ...,
+ ) -> Union[T, VT]:
+ ...
+
@synchronized
- def cache_get(key, default=None, callbacks=[], update_metrics=True):
+ def cache_get(
+ key: KT,
+ default: Optional[T] = None,
+ callbacks: Iterable[Callable[[], None]] = [],
+ update_metrics: bool = True,
+ ):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
@@ -203,7 +249,7 @@ class LruCache:
return default
@synchronized
- def cache_set(key, value, callbacks=[]):
+ def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
node = cache.get(key, None)
if node is not None:
# We sometimes store large objects, e.g. dicts, which cause
@@ -232,7 +278,7 @@ class LruCache:
evict()
@synchronized
- def cache_set_default(key, value):
+ def cache_set_default(key: KT, value: VT) -> VT:
node = cache.get(key, None)
if node is not None:
return node.value
@@ -241,8 +287,16 @@ class LruCache:
evict()
return value
+ @overload
+ def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
+ ...
+
+ @overload
+ def cache_pop(key: KT, default: T) -> Union[T, VT]:
+ ...
+
@synchronized
- def cache_pop(key, default=None):
+ def cache_pop(key: KT, default: Optional[T] = None):
node = cache.get(key, None)
if node:
delete_node(node)
@@ -252,18 +306,18 @@ class LruCache:
return default
@synchronized
- def cache_del_multi(key):
+ def cache_del_multi(key: KT) -> None:
"""
This will only work if constructed with cache_type=TreeCache
"""
popped = cache.pop(key)
if popped is None:
return
- for leaf in enumerate_leaves(popped, keylen - len(key)):
+ for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
delete_node(leaf)
@synchronized
- def cache_clear():
+ def cache_clear() -> None:
list_root.next_node = list_root
list_root.prev_node = list_root
for node in cache.values():
@@ -274,7 +328,7 @@ class LruCache:
cached_cache_len[0] = 0
@synchronized
- def cache_contains(key):
+ def cache_contains(key: KT) -> bool:
return key in cache
self.sentinel = object()
|