summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-10-16 17:06:50 +0100
committerGitHub <noreply@github.com>2020-10-16 17:06:50 +0100
commitd6094176d1828b7d169016910f5381b90310a885 (patch)
tree791f582ee2210404f9fe8c1c8317ee42d7080453 /synapse
parentClean-up old transaction IDs on the background worker. (#8544) (diff)
parentreview comments (diff)
downloadsynapse-d6094176d1828b7d169016910f5381b90310a885.tar.xz
Type annotations for LruCache (#8562)
* type annotations for LruCache

* changelog

* Apply suggestions from code review

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>

* review comments

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py4
-rw-r--r--synapse/push/push_rule_evaluator.py16
-rw-r--r--synapse/util/caches/deferred_cache.py5
-rw-r--r--synapse/util/caches/dictionary_cache.py22
-rw-r--r--synapse/util/caches/lrucache.py78
5 files changed, 94 insertions, 31 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index eb6f418b13..bff87fabde 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -69,7 +69,9 @@ class Auth:
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
 
-        self.token_cache = LruCache(10000, "token_cache")
+        self.token_cache = LruCache(
+            10000, "token_cache"
+        )  # type: LruCache[str, Tuple[str, bool]]
 
         self._auth_blocking = AuthBlocking(self.hs)
 
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 4c95b149c5..2ce9e444ab 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,7 +16,7 @@
 
 import logging
 import re
-from typing import Any, Dict, List, Optional, Pattern, Union
+from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
 
 from synapse.events import EventBase
 from synapse.types import UserID
@@ -173,19 +173,21 @@ class PushRuleEvaluatorForEvent:
         # Similar to _glob_matches, but do not treat display_name as a glob.
         r = regex_cache.get((display_name, False, True), None)
         if not r:
-            r = re.escape(display_name)
-            r = _re_word_boundary(r)
-            r = re.compile(r, flags=re.IGNORECASE)
+            r1 = re.escape(display_name)
+            r1 = _re_word_boundary(r1)
+            r = re.compile(r1, flags=re.IGNORECASE)
             regex_cache[(display_name, False, True)] = r
 
-        return r.search(body)
+        return bool(r.search(body))
 
     def _get_value(self, dotted_key: str) -> Optional[str]:
         return self._value_cache.get(dotted_key, None)
 
 
 # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
-regex_cache = LruCache(50000, "regex_push_cache")
+regex_cache = LruCache(
+    50000, "regex_push_cache"
+)  # type: LruCache[Tuple[str, bool, bool], Pattern]
 
 
 def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
@@ -203,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
         if not r:
             r = _glob_to_re(glob, word_boundary)
             regex_cache[(glob, True, word_boundary)] = r
-        return r.search(value)
+        return bool(r.search(value))
     except re.error:
         logger.warning("Failed to parse glob to regex: %r", glob)
         return False
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 91fdc8142d..4026e1f8fa 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -98,7 +98,7 @@ class DeferredCache(Generic[KT, VT]):
             size_callback=(lambda d: len(d)) if iterable else None,
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
-        )
+        )  # type: LruCache[KT, VT]
 
         self.thread = None  # type: Optional[threading.Thread]
 
@@ -240,11 +240,12 @@ class DeferredCache(Generic[KT, VT]):
         self.check_thread()
         if not isinstance(key, tuple):
             raise TypeError("The cache key must be a tuple not %r" % (type(key),))
+        key = cast(KT, key)
         self.cache.del_multi(key)
 
         # if we have a pending lookup for this key, remove it from the
         # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
+        entry_dict = self._pending_deferred_cache.pop(key, None)
         if entry_dict is not None:
             for entry in iterate_tree_cache_entry(entry_dict):
                 entry.invalidate()
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 8b426c005b..588d2d49f2 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -12,10 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 import logging
 import threading
 from collections import namedtuple
+from typing import Any
 
 from synapse.util.caches.lrucache import LruCache
 
@@ -38,23 +39,26 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
         return len(self.value)
 
 
+class _Sentinel(enum.Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
+
+
 class DictionaryCache:
     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
     fetching a subset of dictionary keys for a particular key.
     """
 
     def __init__(self, name, max_entries=1000):
-        self.cache = LruCache(max_size=max_entries, cache_name=name, size_callback=len)
+        self.cache = LruCache(
+            max_size=max_entries, cache_name=name, size_callback=len
+        )  # type: LruCache[Any, DictionaryEntry]
 
         self.name = name
         self.sequence = 0
         self.thread = None
 
-        class Sentinel:
-            __slots__ = []
-
-        self.sentinel = Sentinel()
-
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
@@ -76,8 +80,8 @@ class DictionaryCache:
         Returns:
             DictionaryEntry
         """
-        entry = self.cache.get(key, self.sentinel)
-        if entry is not self.sentinel:
+        entry = self.cache.get(key, _Sentinel.sentinel)
+        if entry is not _Sentinel.sentinel:
             if dict_keys is None:
                 return DictionaryEntry(
                     entry.full, entry.known_absent, dict(entry.value)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index e4804f79e0..4e95dd9bf3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -15,12 +15,35 @@
 
 import threading
 from functools import wraps
-from typing import Callable, Optional, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
+
+from typing_extensions import Literal
 
 from synapse.config import cache as cache_config
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.treecache import TreeCache
 
+# Function type: the type used for invalidation callbacks
+FT = TypeVar("FT", bound=Callable[..., Any])
+
+# Key and Value type for the cache
+KT = TypeVar("KT")
+VT = TypeVar("VT")
+
+# a general type var, distinct from either KT or VT
+T = TypeVar("T")
+
 
 def enumerate_leaves(node, depth):
     if depth == 0:
@@ -42,7 +65,7 @@ class _Node:
         self.callbacks = callbacks
 
 
-class LruCache:
+class LruCache(Generic[KT, VT]):
     """
     Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
 
@@ -128,13 +151,13 @@ class LruCache:
                 if metrics:
                     metrics.inc_evictions(evicted_len)
 
-        def synchronized(f):
+        def synchronized(f: FT) -> FT:
             @wraps(f)
             def inner(*args, **kwargs):
                 with lock:
                     return f(*args, **kwargs)
 
-            return inner
+            return cast(FT, inner)
 
         cached_cache_len = [0]
         if size_callback is not None:
@@ -188,8 +211,31 @@ class LruCache:
             node.callbacks.clear()
             return deleted_len
 
+        @overload
+        def cache_get(
+            key: KT,
+            default: Literal[None] = None,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_get(
+            key: KT,
+            default: T,
+            callbacks: Iterable[Callable[[], None]] = ...,
+            update_metrics: bool = ...,
+        ) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_get(key, default=None, callbacks=[], update_metrics=True):
+        def cache_get(
+            key: KT,
+            default: Optional[T] = None,
+            callbacks: Iterable[Callable[[], None]] = [],
+            update_metrics: bool = True,
+        ):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
@@ -203,7 +249,7 @@ class LruCache:
                 return default
 
         @synchronized
-        def cache_set(key, value, callbacks=[]):
+        def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
             node = cache.get(key, None)
             if node is not None:
                 # We sometimes store large objects, e.g. dicts, which cause
@@ -232,7 +278,7 @@ class LruCache:
             evict()
 
         @synchronized
-        def cache_set_default(key, value):
+        def cache_set_default(key: KT, value: VT) -> VT:
             node = cache.get(key, None)
             if node is not None:
                 return node.value
@@ -241,8 +287,16 @@ class LruCache:
                 evict()
                 return value
 
+        @overload
+        def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
+            ...
+
+        @overload
+        def cache_pop(key: KT, default: T) -> Union[T, VT]:
+            ...
+
         @synchronized
-        def cache_pop(key, default=None):
+        def cache_pop(key: KT, default: Optional[T] = None):
             node = cache.get(key, None)
             if node:
                 delete_node(node)
@@ -252,18 +306,18 @@ class LruCache:
                 return default
 
         @synchronized
-        def cache_del_multi(key):
+        def cache_del_multi(key: KT) -> None:
             """
             This will only work if constructed with cache_type=TreeCache
             """
             popped = cache.pop(key)
             if popped is None:
                 return
-            for leaf in enumerate_leaves(popped, keylen - len(key)):
+            for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
                 delete_node(leaf)
 
         @synchronized
-        def cache_clear():
+        def cache_clear() -> None:
             list_root.next_node = list_root
             list_root.prev_node = list_root
             for node in cache.values():
@@ -274,7 +328,7 @@ class LruCache:
                 cached_cache_len[0] = 0
 
         @synchronized
-        def cache_contains(key):
+        def cache_contains(key: KT) -> bool:
             return key in cache
 
         self.sentinel = object()