diff options
Diffstat (limited to 'synapse/util/caches')
-rw-r--r-- | synapse/util/caches/deferred_cache.py | 35 |
1 files changed, 25 insertions, 10 deletions
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 4026e1f8fa..faeef75506 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -17,7 +17,16 @@ import enum import threading -from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast +from typing import ( + Callable, + Generic, + Iterable, + MutableMapping, + Optional, + TypeVar, + Union, + cast, +) from prometheus_client import Gauge @@ -33,7 +42,7 @@ cache_pending_metric = Gauge( ["name"], ) - +T = TypeVar("T") KT = TypeVar("KT") VT = TypeVar("VT") @@ -119,21 +128,21 @@ class DeferredCache(Generic[KT, VT]): def get( self, key: KT, - default=_Sentinel.sentinel, callback: Optional[Callable[[], None]] = None, update_metrics: bool = True, - ): + ) -> Union[ObservableDeferred, VT]: """Looks the key up in the caches. Args: key(tuple) - default: What is returned if key is not in the caches. If not - specified then function throws KeyError instead callback(fn): Gets called when the entry in the cache is invalidated update_metrics (bool): whether to update the cache hit rate metrics Returns: Either an ObservableDeferred or the result itself + + Raises: + KeyError if the key is not found in the cache """ callbacks = [callback] if callback else [] val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) @@ -145,13 +154,19 @@ class DeferredCache(Generic[KT, VT]): m.inc_hits() return val.deferred - val = self.cache.get( - key, default, callbacks=callbacks, update_metrics=update_metrics + val2 = self.cache.get( + key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics ) - if val is _Sentinel.sentinel: + if val2 is _Sentinel.sentinel: raise KeyError() else: - return val + return val2 + + def get_immediate( + self, key: KT, default: T, update_metrics: bool = True + ) -> Union[VT, T]: + """If we have a *completed* cached value, return it.""" + return self.cache.get(key, default, update_metrics=update_metrics) def set( self, |