summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2020-10-14 19:40:53 +0100
committerRichard van der Hoff <richard@matrix.org>2020-10-14 23:38:14 +0100
commit7eff59ec91e59140c375b43a6dac05b833ab0051 (patch)
treea5f570feda94700626c85a9bc3130586a52256d2 /synapse/util/caches/descriptors.py
parentMove additional tasks to the background worker, part 4 (#8513) (diff)
downloadsynapse-7eff59ec91e59140c375b43a6dac05b833ab0051.tar.xz
Add some more type annotations to Cache
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py81
1 files changed, 59 insertions, 22 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 98b34f2223..14458bc20f 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,12 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 import functools
 import inspect
 import logging
 import threading
-from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    MutableMapping,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 from weakref import WeakValueDictionary
 
 from prometheus_client import Gauge
@@ -38,6 +49,8 @@ logger = logging.getLogger(__name__)
 CacheKey = Union[Tuple, Any]
 
 F = TypeVar("F", bound=Callable[..., Any])
+KT = TypeVar("KT")
+VT = TypeVar("VT")
 
 
 class _CachedFunction(Generic[F]):
@@ -61,13 +74,19 @@ cache_pending_metric = Gauge(
     ["name"],
 )
 
-_CacheSentinel = object()
+
+class _Sentinel(enum.Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup.
+    sentinel = object()
 
 
 class CacheEntry:
     __slots__ = ["deferred", "callbacks", "invalidated"]
 
-    def __init__(self, deferred, callbacks):
+    def __init__(
+        self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
+    ):
         self.deferred = deferred
         self.callbacks = set(callbacks)
         self.invalidated = False
@@ -80,7 +99,13 @@ class CacheEntry:
             self.callbacks.clear()
 
 
-class Cache:
+class Cache(Generic[KT, VT]):
+    """Wraps an LruCache, adding support for Deferred results.
+
+    It expects that each entry added with set() will be a Deferred; likewise get()
+    may return an ObservableDeferred.
+    """
+
     __slots__ = (
         "cache",
         "name",
@@ -103,19 +128,23 @@ class Cache:
         Args:
             name: The name of the cache
             max_entries: Maximum amount of entries that the cache will hold
-            keylen: The length of the tuple used as the cache key
+            keylen: The length of the tuple used as the cache key. Ignored unless
+               `tree` is True.
             tree: Use a TreeCache instead of a dict as the underlying cache type
             iterable: If True, count each item in the cached object as an entry,
                 rather than each cached object
             apply_cache_factor_from_config: Whether cache factors specified in the
                 config file affect `max_entries`
-
-        Returns:
-            Cache
         """
         cache_type = TreeCache if tree else dict
-        self._pending_deferred_cache = cache_type()
 
+        # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
+        self._pending_deferred_cache = (
+            cache_type()
+        )  # type: MutableMapping[KT, CacheEntry]
+
+        # cache is used for completed results and maps to the result itself, rather than
+        # a Deferred.
         self.cache = LruCache(
             max_size=max_entries,
             keylen=keylen,
@@ -155,7 +184,13 @@ class Cache:
                     "Cache objects can only be accessed from the main thread"
                 )
 
-    def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
+    def get(
+        self,
+        key: KT,
+        default=_Sentinel.sentinel,
+        callback: Optional[Callable[[], None]] = None,
+        update_metrics: bool = True,
+    ):
         """Looks the key up in the caches.
 
         Args:
@@ -166,30 +201,32 @@ class Cache:
             update_metrics (bool): whether to update the cache hit rate metrics
 
         Returns:
-            Either an ObservableDeferred or the raw result
+            Either an ObservableDeferred or the result itself
         """
         callbacks = [callback] if callback else []
-        val = self._pending_deferred_cache.get(key, _CacheSentinel)
-        if val is not _CacheSentinel:
+        val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
+        if val is not _Sentinel.sentinel:
             val.callbacks.update(callbacks)
             if update_metrics:
                 self.metrics.inc_hits()
             return val.deferred
 
-        val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
-        if val is not _CacheSentinel:
+        val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
+        if val is not _Sentinel.sentinel:
             self.metrics.inc_hits()
             return val
 
         if update_metrics:
             self.metrics.inc_misses()
 
-        if default is _CacheSentinel:
+        if default is _Sentinel.sentinel:
             raise KeyError()
         else:
             return default
 
-    def set(self, key, value, callback=None):
+    def set(
+        self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None
+    ) -> ObservableDeferred:
         if not isinstance(value, defer.Deferred):
             raise TypeError("not a Deferred")
 
@@ -248,7 +285,7 @@ class Cache:
         observer.addCallbacks(cb, eb)
         return observable
 
-    def prefill(self, key, value, callback=None):
+    def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
         callbacks = [callback] if callback else []
         self.cache.set(key, value, callbacks=callbacks)
 
@@ -267,7 +304,7 @@ class Cache:
         if entry:
             entry.invalidate()
 
-    def invalidate_many(self, key):
+    def invalidate_many(self, key: KT):
         self.check_thread()
         if not isinstance(key, tuple):
             raise TypeError("The cache key must be a tuple not %r" % (type(key),))
@@ -275,7 +312,7 @@ class Cache:
 
         # if we have a pending lookup for this key, remove it from the
         # _pending_deferred_cache, as above
-        entry_dict = self._pending_deferred_cache.pop(key, None)
+        entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
         if entry_dict is not None:
             for entry in iterate_tree_cache_entry(entry_dict):
                 entry.invalidate()
@@ -396,7 +433,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             keylen=self.num_args,
             tree=self.tree,
             iterable=self.iterable,
-        )
+        )  # type: Cache[Tuple, Any]
 
         def get_cache_key_gen(args, kwargs):
             """Given some args/kwargs return a generator that resolves into