summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/__init__.py4
-rw-r--r--synapse/util/caches/cached_call.py6
-rw-r--r--synapse/util/caches/deferred_cache.py12
-rw-r--r--synapse/util/caches/descriptors.py36
-rw-r--r--synapse/util/caches/dictionary_cache.py6
-rw-r--r--synapse/util/caches/expiringcache.py4
-rw-r--r--synapse/util/caches/lrucache.py11
-rw-r--r--synapse/util/caches/response_cache.py2
-rw-r--r--synapse/util/caches/stream_change_cache.py6
-rw-r--r--synapse/util/caches/treecache.py3
-rw-r--r--synapse/util/caches/ttlcache.py6
11 files changed, 47 insertions, 49 deletions
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index ca36f07c20..9012034b7a 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -29,8 +29,8 @@ logger = logging.getLogger(__name__)
 TRACK_MEMORY_USAGE = False
 
 
-caches_by_name = {}  # type: Dict[str, Sized]
-collectors_by_name = {}  # type: Dict[str, CacheMetric]
+caches_by_name: Dict[str, Sized] = {}
+collectors_by_name: Dict[str, "CacheMetric"] = {}
 
 cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
 cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index a301c9e89b..891bee0b33 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -63,9 +63,9 @@ class CachedCall(Generic[TV]):
             f: The underlying function. Only one call to this function will be alive
                 at once (per instance of CachedCall)
         """
-        self._callable = f  # type: Optional[Callable[[], Awaitable[TV]]]
-        self._deferred = None  # type: Optional[Deferred]
-        self._result = None  # type: Union[None, Failure, TV]
+        self._callable: Optional[Callable[[], Awaitable[TV]]] = f
+        self._deferred: Optional[Deferred] = None
+        self._result: Union[None, Failure, TV] = None
 
     async def get(self) -> TV:
         """Kick off the call if necessary, and return the result"""
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 1044139119..8c6fafc677 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -80,25 +80,25 @@ class DeferredCache(Generic[KT, VT]):
         cache_type = TreeCache if tree else dict
 
         # _pending_deferred_cache maps from the key value to a `CacheEntry` object.
-        self._pending_deferred_cache = (
-            cache_type()
-        )  # type: Union[TreeCache, MutableMapping[KT, CacheEntry]]
+        self._pending_deferred_cache: Union[
+            TreeCache, "MutableMapping[KT, CacheEntry]"
+        ] = cache_type()
 
         def metrics_cb():
             cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
 
         # cache is used for completed results and maps to the result itself, rather than
         # a Deferred.
-        self.cache = LruCache(
+        self.cache: LruCache[KT, VT] = LruCache(
             max_size=max_entries,
             cache_name=name,
             cache_type=cache_type,
             size_callback=(lambda d: len(d) or 1) if iterable else None,
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
-        )  # type: LruCache[KT, VT]
+        )
 
-        self.thread = None  # type: Optional[threading.Thread]
+        self.thread: Optional[threading.Thread] = None
 
     @property
     def max_entries(self):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index d77e8edeea..1e8e6b1d01 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -46,17 +46,17 @@ F = TypeVar("F", bound=Callable[..., Any])
 
 
 class _CachedFunction(Generic[F]):
-    invalidate = None  # type: Any
-    invalidate_all = None  # type: Any
-    prefill = None  # type: Any
-    cache = None  # type: Any
-    num_args = None  # type: Any
+    invalidate: Any = None
+    invalidate_all: Any = None
+    prefill: Any = None
+    cache: Any = None
+    num_args: Any = None
 
-    __name__ = None  # type: str
+    __name__: str
 
     # Note: This function signature is actually fiddled with by the synapse mypy
     # plugin to a) make it a bound method, and b) remove any `cache_context` arg.
-    __call__ = None  # type: F
+    __call__: F
 
 
 class _CacheDescriptorBase:
@@ -115,8 +115,8 @@ class _CacheDescriptorBase:
 
 
 class _LruCachedFunction(Generic[F]):
-    cache = None  # type: LruCache[CacheKey, Any]
-    __call__ = None  # type: F
+    cache: LruCache[CacheKey, Any]
+    __call__: F
 
 
 def lru_cache(
@@ -180,10 +180,10 @@ class LruCacheDescriptor(_CacheDescriptorBase):
         self.max_entries = max_entries
 
     def __get__(self, obj, owner):
-        cache = LruCache(
+        cache: LruCache[CacheKey, Any] = LruCache(
             cache_name=self.orig.__name__,
             max_size=self.max_entries,
-        )  # type: LruCache[CacheKey, Any]
+        )
 
         get_cache_key = self.cache_key_builder
         sentinel = LruCacheDescriptor._Sentinel.sentinel
@@ -271,12 +271,12 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
 
     def __get__(self, obj, owner):
-        cache = DeferredCache(
+        cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             tree=self.tree,
             iterable=self.iterable,
-        )  # type: DeferredCache[CacheKey, Any]
+        )
 
         get_cache_key = self.cache_key_builder
 
@@ -359,7 +359,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj, objtype=None):
         cached_method = getattr(obj, self.cached_method_name)
-        cache = cached_method.cache  # type: DeferredCache[CacheKey, Any]
+        cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
@@ -472,15 +472,15 @@ class _CacheContext:
 
     Cache = Union[DeferredCache, LruCache]
 
-    _cache_context_objects = (
-        WeakValueDictionary()
-    )  # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext]
+    _cache_context_objects: """WeakValueDictionary[
+        Tuple["_CacheContext.Cache", CacheKey], "_CacheContext"
+    ]""" = WeakValueDictionary()
 
     def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None:
         self._cache = cache
         self._cache_key = cache_key
 
-    def invalidate(self):  # type: () -> None
+    def invalidate(self) -> None:
         """Invalidates the cache entry referred to by the context."""
         self._cache.invalidate(self._cache_key)
 
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 56d94d96ce..3f852edd7f 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -62,13 +62,13 @@ class DictionaryCache(Generic[KT, DKT]):
     """
 
     def __init__(self, name: str, max_entries: int = 1000):
-        self.cache = LruCache(
+        self.cache: LruCache[KT, DictionaryEntry] = LruCache(
             max_size=max_entries, cache_name=name, size_callback=len
-        )  # type: LruCache[KT, DictionaryEntry]
+        )
 
         self.name = name
         self.sequence = 0
-        self.thread = None  # type: Optional[threading.Thread]
+        self.thread: Optional[threading.Thread] = None
 
     def check_thread(self) -> None:
         expected_thread = self.thread
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index ac47a31cd7..bde16b8577 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -27,7 +27,7 @@ from synapse.util.caches import register_cache
 logger = logging.getLogger(__name__)
 
 
-SENTINEL = object()  # type: Any
+SENTINEL: Any = object()
 
 
 T = TypeVar("T")
@@ -71,7 +71,7 @@ class ExpiringCache(Generic[KT, VT]):
         self._expiry_ms = expiry_ms
         self._reset_expiry_on_get = reset_expiry_on_get
 
-        self._cache = OrderedDict()  # type: OrderedDict[KT, _CacheEntry]
+        self._cache: OrderedDict[KT, _CacheEntry] = OrderedDict()
 
         self.iterable = iterable
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 4b9d0433ff..5c65d187b6 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -90,8 +90,7 @@ def enumerate_leaves(node, depth):
         yield node
     else:
         for n in node.values():
-            for m in enumerate_leaves(n, depth - 1):
-                yield m
+            yield from enumerate_leaves(n, depth - 1)
 
 
 P = TypeVar("P")
@@ -226,7 +225,7 @@ class _Node:
         # footprint down. Storing `None` is free as its a singleton, while empty
         # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive
         # thing and used sets).
-        self.callbacks = None  # type: Optional[List[Callable[[], None]]]
+        self.callbacks: Optional[List[Callable[[], None]]] = None
 
         self.add_callbacks(callbacks)
 
@@ -362,15 +361,15 @@ class LruCache(Generic[KT, VT]):
 
         # register_cache might call our "set_cache_factor" callback; there's nothing to
         # do yet when we get resized.
-        self._on_resize = None  # type: Optional[Callable[[],None]]
+        self._on_resize: Optional[Callable[[], None]] = None
 
         if cache_name is not None:
-            metrics = register_cache(
+            metrics: Optional[CacheMetric] = register_cache(
                 "lru_cache",
                 cache_name,
                 self,
                 collect_callback=metrics_collection_callback,
-            )  # type: Optional[CacheMetric]
+            )
         else:
             metrics = None
 
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 34c662c4db..ed7204336f 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -66,7 +66,7 @@ class ResponseCache(Generic[KV]):
         # This is poorly-named: it includes both complete and incomplete results.
         # We keep complete results rather than switching to absolute values because
         # that makes it easier to cache Failure results.
-        self.pending_result_cache = {}  # type: Dict[KV, ObservableDeferred]
+        self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
 
         self.clock = clock
         self.timeout_sec = timeout_ms / 1000.0
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index e81e468899..3a41a8baa6 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -45,10 +45,10 @@ class StreamChangeCache:
     ):
         self._original_max_size = max_size
         self._max_size = math.floor(max_size)
-        self._entity_to_key = {}  # type: Dict[EntityType, int]
+        self._entity_to_key: Dict[EntityType, int] = {}
 
         # map from stream id to the a set of entities which changed at that stream id.
-        self._cache = SortedDict()  # type: SortedDict[int, Set[EntityType]]
+        self._cache: SortedDict[int, Set[EntityType]] = SortedDict()
 
         # the earliest stream_pos for which we can reliably answer
         # get_all_entities_changed. In other words, one less than the earliest
@@ -155,7 +155,7 @@ class StreamChangeCache:
         if stream_pos < self._earliest_known_stream_pos:
             return None
 
-        changed_entities = []  # type: List[EntityType]
+        changed_entities: List[EntityType] = []
 
         for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
             changed_entities.extend(self._cache[k])
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index a6df81ebff..4138931e7b 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -138,7 +138,6 @@ def iterate_tree_cache_entry(d):
     """
     if isinstance(d, TreeCacheNode):
         for value_d in d.values():
-            for value in iterate_tree_cache_entry(value_d):
-                yield value
+            yield from iterate_tree_cache_entry(value_d)
     else:
         yield d
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index c276107d56..46afe3f934 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -23,7 +23,7 @@ from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
-SENTINEL = object()  # type: Any
+SENTINEL: Any = object()
 
 T = TypeVar("T")
 KT = TypeVar("KT")
@@ -35,10 +35,10 @@ class TTLCache(Generic[KT, VT]):
 
     def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
         # map from key to _CacheEntry
-        self._data = {}  # type: Dict[KT, _CacheEntry]
+        self._data: Dict[KT, _CacheEntry] = {}
 
         # the _CacheEntries, sorted by expiry time
-        self._expiry_list = SortedList()  # type: SortedList[_CacheEntry]
+        self._expiry_list: SortedList[_CacheEntry] = SortedList()
 
         self._timer = timer