summary refs log tree commit diff
path: root/synapse/util/caches/expiringcache.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/expiringcache.py')
-rw-r--r--synapse/util/caches/expiringcache.py83
1 files changed, 51 insertions, 32 deletions
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index e15f7ee698..4dc3477e89 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -15,40 +15,50 @@
 
 import logging
 from collections import OrderedDict
+from typing import Any, Generic, Optional, TypeVar, Union, overload
+
+import attr
+from typing_extensions import Literal
 
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
 from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
 
-SENTINEL = object()
+SENTINEL = object()  # type: Any
+
 
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
 
-class ExpiringCache:
+
+class ExpiringCache(Generic[KT, VT]):
     def __init__(
         self,
-        cache_name,
-        clock,
-        max_len=0,
-        expiry_ms=0,
-        reset_expiry_on_get=False,
-        iterable=False,
+        cache_name: str,
+        clock: Clock,
+        max_len: int = 0,
+        expiry_ms: int = 0,
+        reset_expiry_on_get: bool = False,
+        iterable: bool = False,
     ):
         """
         Args:
-            cache_name (str): Name of this cache, used for logging.
-            clock (Clock)
-            max_len (int): Max size of dict. If the dict grows larger than this
+            cache_name: Name of this cache, used for logging.
+            clock
+            max_len: Max size of dict. If the dict grows larger than this
                 then the oldest items get automatically evicted. Default is 0,
                 which indicates there is no max limit.
-            expiry_ms (int): How long before an item is evicted from the cache
+            expiry_ms: How long before an item is evicted from the cache
                 in milliseconds. Default is 0, indicating items never get
                 evicted based on time.
-            reset_expiry_on_get (bool): If true, will reset the expiry time for
+            reset_expiry_on_get: If true, will reset the expiry time for
                 an item on access. Defaults to False.
-            iterable (bool): If true, the size is calculated by summing the
+            iterable: If true, the size is calculated by summing the
                 sizes of all entries, rather than the number of entries.
         """
         self._cache_name = cache_name
@@ -62,7 +72,7 @@ class ExpiringCache:
         self._expiry_ms = expiry_ms
         self._reset_expiry_on_get = reset_expiry_on_get
 
-        self._cache = OrderedDict()
+        self._cache = OrderedDict()  # type: OrderedDict[KT, _CacheEntry]
 
         self.iterable = iterable
 
@@ -79,12 +89,12 @@ class ExpiringCache:
 
         self._clock.looping_call(f, self._expiry_ms / 2)
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: KT, value: VT) -> None:
         now = self._clock.time_msec()
         self._cache[key] = _CacheEntry(now, value)
         self.evict()
 
-    def evict(self):
+    def evict(self) -> None:
         # Evict if there are now too many items
         while self._max_size and len(self) > self._max_size:
             _key, value = self._cache.popitem(last=False)
@@ -93,7 +103,7 @@ class ExpiringCache:
             else:
                 self.metrics.inc_evictions()
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: KT) -> VT:
         try:
             entry = self._cache[key]
             self.metrics.inc_hits()
@@ -106,7 +116,7 @@ class ExpiringCache:
 
         return entry.value
 
-    def pop(self, key, default=SENTINEL):
+    def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
         """Removes and returns the value with the given key from the cache.
 
         If the key isn't in the cache then `default` will be returned if
@@ -115,29 +125,40 @@ class ExpiringCache:
         Identical functionality to `dict.pop(..)`.
         """
 
-        value = self._cache.pop(key, default)
+        value = self._cache.pop(key, SENTINEL)
+        # The key was not found.
         if value is SENTINEL:
-            raise KeyError(key)
+            if default is SENTINEL:
+                raise KeyError(key)
+            return default
 
-        return value
+        return value.value
 
-    def __contains__(self, key):
+    def __contains__(self, key: KT) -> bool:
         return key in self._cache
 
-    def get(self, key, default=None):
+    @overload
+    def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]:
+        ...
+
+    @overload
+    def get(self, key: KT, default: T) -> Union[VT, T]:
+        ...
+
+    def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
         try:
             return self[key]
         except KeyError:
             return default
 
-    def setdefault(self, key, value):
+    def setdefault(self, key: KT, value: VT) -> VT:
         try:
             return self[key]
         except KeyError:
             self[key] = value
             return value
 
-    def _prune_cache(self):
+    def _prune_cache(self) -> None:
         if not self._expiry_ms:
             # zero expiry time means don't expire. This should never get called
             # since we have this check in start too.
@@ -166,7 +187,7 @@ class ExpiringCache:
             len(self),
         )
 
-    def __len__(self):
+    def __len__(self) -> int:
         if self.iterable:
             return sum(len(entry.value) for entry in self._cache.values())
         else:
@@ -190,9 +211,7 @@ class ExpiringCache:
         return False
 
 
+@attr.s(slots=True)
 class _CacheEntry:
-    __slots__ = ["time", "value"]
-
-    def __init__(self, time, value):
-        self.time = time
-        self.value = value
+    time = attr.ib(type=int)
+    value = attr.ib()