diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index df4fb156c2..1cdead02f1 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,7 @@ import inspect
import logging
from typing import (
Any,
+ Awaitable,
Callable,
Dict,
Generic,
@@ -346,15 +347,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
"""Wraps an existing cache to support bulk fetching of keys.
Given an iterable of keys it looks in the cache to find any hits, then passes
- the tuple of missing keys to the wrapped function.
+ the set of missing keys to the wrapped function.
- Once wrapped, the function returns a Deferred which resolves to the list
- of results.
+ Once wrapped, the function returns a Deferred which resolves to a Dict mapping from
+ input key to output value.
"""
def __init__(
self,
- orig: Callable[..., Any],
+ orig: Callable[..., Awaitable[Dict]],
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
@@ -385,13 +386,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
def __get__(
self, obj: Optional[Any], objtype: Optional[Type] = None
- ) -> Callable[..., Any]:
+ ) -> Callable[..., "defer.Deferred[Dict[Hashable, Any]]"]:
cached_method = getattr(obj, self.cached_method_name)
cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args
@functools.wraps(self.orig)
- def wrapped(*args: Any, **kwargs: Any) -> Any:
+ def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
# If we're passed a cache_context then we'll want to call its
# invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -444,39 +445,38 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
deferred: "defer.Deferred[Any]" = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
- cache.set(key, deferred, callback=invalidate_callback)
+ cached_defers.append(
+ cache.set(key, deferred, callback=invalidate_callback)
+ )
def complete_all(res: Dict[Hashable, Any]) -> None:
- # the wrapped function has completed. It returns a
- # a dict. We can now resolve the observable deferreds in
- # the cache and update our own result map.
- for e in missing:
+ # the wrapped function has completed. It returns a dict.
+ # We can now update our own result map, and then resolve the
+ # observable deferreds in the cache.
+ for e, d1 in deferreds_map.items():
val = res.get(e, None)
- deferreds_map[e].callback(val)
+ # make sure we update the results map before running the
+ # deferreds, because as soon as we run the last deferred, the
+ # gatherResults() below will complete and return the result
+ # dict to our caller.
results[e] = val
+ d1.callback(val)
- def errback(f: Failure) -> Failure:
- # the wrapped function has failed. Invalidate any cache
- # entries we're supposed to be populating, and fail
- # their deferreds.
- for e in missing:
- key = arg_to_cache_key(e)
- cache.invalidate(key)
- deferreds_map[e].errback(f)
-
- # return the failure, to propagate to our caller.
- return f
+ def errback_all(f: Failure) -> None:
+ # the wrapped function has failed. Propagate the failure into
+ # the cache, which will invalidate the entry, and cause the
+ # relevant cached_deferreds to fail, which will propagate the
+ # failure to our caller.
+ for d1 in deferreds_map.values():
+ d1.errback(f)
args_to_call = dict(arg_dict)
- # copy the missing set before sending it to the callee, to guard against
- # modification.
- args_to_call[self.list_name] = tuple(missing)
-
- cached_defers.append(
- defer.maybeDeferred(
- preserve_fn(self.orig), **args_to_call
- ).addCallbacks(complete_all, errback)
- )
+ args_to_call[self.list_name] = missing
+
+ # dispatch the call, and attach the two handlers
+ defer.maybeDeferred(
+ preserve_fn(self.orig), **args_to_call
+ ).addCallbacks(complete_all, errback_all)
if cached_defers:
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
|