diff options
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r-- | synapse/util/caches/descriptors.py | 64 |
1 files changed, 32 insertions, 32 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index df4fb156c2..1cdead02f1 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -18,6 +18,7 @@ import inspect import logging from typing import ( Any, + Awaitable, Callable, Dict, Generic, @@ -346,15 +347,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given an iterable of keys it looks in the cache to find any hits, then passes - the tuple of missing keys to the wrapped function. + the set of missing keys to the wrapped function. - Once wrapped, the function returns a Deferred which resolves to the list - of results. + Once wrapped, the function returns a Deferred which resolves to a Dict mapping from + input key to output value. """ def __init__( self, - orig: Callable[..., Any], + orig: Callable[..., Awaitable[Dict]], cached_method_name: str, list_name: str, num_args: Optional[int] = None, @@ -385,13 +386,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): def __get__( self, obj: Optional[Any], objtype: Optional[Type] = None - ) -> Callable[..., Any]: + ) -> Callable[..., "defer.Deferred[Dict[Hashable, Any]]"]: cached_method = getattr(obj, self.cached_method_name) cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args @functools.wraps(self.orig) - def wrapped(*args: Any, **kwargs: Any) -> Any: + def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": # If we're passed a cache_context then we'll want to call its # invalidate() whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) @@ -444,39 +445,38 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): deferred: "defer.Deferred[Any]" = defer.Deferred() deferreds_map[arg] = deferred key = arg_to_cache_key(arg) - cache.set(key, deferred, callback=invalidate_callback) + cached_defers.append( + cache.set(key, deferred, callback=invalidate_callback) + ) def complete_all(res: Dict[Hashable, Any]) -> None: - # the wrapped function has completed. It returns a - # a dict. We can now resolve the observable deferreds in - # the cache and update our own result map. - for e in missing: + # the wrapped function has completed. It returns a dict. + # We can now update our own result map, and then resolve the + # observable deferreds in the cache. + for e, d1 in deferreds_map.items(): val = res.get(e, None) - deferreds_map[e].callback(val) + # make sure we update the results map before running the + # deferreds, because as soon as we run the last deferred, the + # gatherResults() below will complete and return the result + # dict to our caller. results[e] = val + d1.callback(val) - def errback(f: Failure) -> Failure: - # the wrapped function has failed. Invalidate any cache - # entries we're supposed to be populating, and fail - # their deferreds. - for e in missing: - key = arg_to_cache_key(e) - cache.invalidate(key) - deferreds_map[e].errback(f) - - # return the failure, to propagate to our caller. - return f + def errback_all(f: Failure) -> None: + # the wrapped function has failed. Propagate the failure into + # the cache, which will invalidate the entry, and cause the + # relevant cached_deferreds to fail, which will propagate the + # failure to our caller. + for d1 in deferreds_map.values(): + d1.errback(f) args_to_call = dict(arg_dict) - # copy the missing set before sending it to the callee, to guard against - # modification. - args_to_call[self.list_name] = tuple(missing) - - cached_defers.append( - defer.maybeDeferred( - preserve_fn(self.orig), **args_to_call - ).addCallbacks(complete_all, errback) - ) + args_to_call[self.list_name] = missing + + # dispatch the call, and attach the two handlers + defer.maybeDeferred( + preserve_fn(self.orig), **args_to_call + ).addCallbacks(complete_all, errback_all) if cached_defers: d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( |