summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/cached_call.py27
-rw-r--r--synapse/util/caches/deferred_cache.py15
-rw-r--r--synapse/util/caches/descriptors.py2
3 files changed, 30 insertions, 14 deletions
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py

index 891bee0b33..e58dd91eda 100644 --- a/synapse/util/caches/cached_call.py +++ b/synapse/util/caches/cached_call.py
@@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union from twisted.internet.defer import Deferred @@ -22,6 +22,10 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background TV = TypeVar("TV") +class _Sentinel(enum.Enum): + sentinel = object() + + class CachedCall(Generic[TV]): """A wrapper for asynchronous calls whose results should be shared @@ -65,7 +69,7 @@ class CachedCall(Generic[TV]): """ self._callable: Optional[Callable[[], Awaitable[TV]]] = f self._deferred: Optional[Deferred] = None - self._result: Union[None, Failure, TV] = None + self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel async def get(self) -> TV: """Kick off the call if necessary, and return the result""" @@ -78,8 +82,9 @@ class CachedCall(Generic[TV]): self._callable = None # once the deferred completes, store the result. We cannot simply leave the - # result in the deferred, since if it's a Failure, GCing the deferred - # would then log a critical error about unhandled Failures. + # result in the deferred, since `awaiting` a deferred destroys its result. + # (Also, if it's a Failure, GCing the deferred would log a critical error + # about unhandled Failures) def got_result(r): self._result = r @@ -92,13 +97,15 @@ class CachedCall(Generic[TV]): # and any eventual exception may not be reported. # we can now await the deferred, and once it completes, return the result. - await make_deferred_yieldable(self._deferred) + if isinstance(self._result, _Sentinel): + await make_deferred_yieldable(self._deferred) + assert not isinstance(self._result, _Sentinel) + + if isinstance(self._result, Failure): + self._result.raiseException() + raise AssertionError("unexpected return from Failure.raiseException") - # I *think* this is the easiest way to correctly raise a Failure without having - # to gut-wrench into the implementation of Deferred. - d = Deferred() - d.callback(self._result) - return await d + return self._result class RetryOnExceptionCachedCall(Generic[TV]): diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 8c6fafc677..b6456392cd 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py
@@ -16,7 +16,16 @@ import enum import threading -from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union +from typing import ( + Callable, + Generic, + Iterable, + MutableMapping, + Optional, + TypeVar, + Union, + cast, +) from prometheus_client import Gauge @@ -166,7 +175,7 @@ class DeferredCache(Generic[KT, VT]): def set( self, key: KT, - value: defer.Deferred, + value: "defer.Deferred[VT]", callback: Optional[Callable[[], None]] = None, ) -> defer.Deferred: """Adds a new entry to the cache (or updates an existing one). @@ -214,7 +223,7 @@ class DeferredCache(Generic[KT, VT]): if value.called: result = value.result if not isinstance(result, failure.Failure): - self.cache.set(key, result, callbacks) + self.cache.set(key, cast(VT, result), callbacks) return value # otherwise, we'll add an entry to the _pending_deferred_cache for now, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1e8e6b1d01..1ca31e41ac 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py
@@ -413,7 +413,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): # relevant result for that key. deferreds_map = {} for arg in missing: - deferred = defer.Deferred() + deferred: "defer.Deferred[Any]" = defer.Deferred() deferreds_map[arg] = deferred key = arg_to_cache_key(arg) cache.set(key, deferred, callback=invalidate_callback)