diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..10aff4d04a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,6 +25,7 @@ from typing import (
Generic,
Hashable,
Iterable,
+ List,
Mapping,
Optional,
Sequence,
@@ -73,8 +74,10 @@ class _CacheDescriptorBase:
num_args: Optional[int],
uncached_args: Optional[Collection[str]] = None,
cache_context: bool = False,
+ name: Optional[str] = None,
):
self.orig = orig
+ self.name = name or orig.__name__
arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
@@ -211,7 +214,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache(
- cache_name=self.orig.__name__,
+ cache_name=self.name,
max_size=self.max_entries,
)
@@ -241,7 +244,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
wrapped = cast(_CachedFunction, _wrapped)
wrapped.cache = cache
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -301,12 +304,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
+ name: Optional[str] = None,
):
super().__init__(
orig,
num_args=num_args,
uncached_args=uncached_args,
cache_context=cache_context,
+ name=name,
)
if tree and self.num_args < 2:
@@ -321,7 +326,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache(
- name=self.orig.__name__,
+ name=self.name,
max_entries=self.max_entries,
tree=self.tree,
iterable=self.iterable,
@@ -372,7 +377,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
wrapped.cache = cache
wrapped.num_args = self.num_args
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -393,6 +398,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
+ name: Optional[str] = None,
):
"""
Args:
@@ -403,7 +409,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
- super().__init__(orig, num_args=num_args, uncached_args=None)
+ super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
self.list_name = list_name
@@ -435,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
- results = {}
-
- def update_results_dict(res: Any, arg: Hashable) -> None:
- results[arg] = res
-
- # list of deferreds to wait for
- cached_defers = []
-
- missing = set()
-
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
@@ -452,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
def arg_to_cache_key(arg: Hashable) -> Hashable:
return arg
+ def cache_key_to_arg(key: tuple) -> Hashable:
+ return key
+
else:
keylist = list(keyargs)
@@ -459,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
keylist[self.list_pos] = arg
return tuple(keylist)
- for arg in list_args:
- try:
- res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
- if not res.called:
- res.addCallback(update_results_dict, arg)
- cached_defers.append(res)
- else:
- results[arg] = res.result
- except KeyError:
- missing.add(arg)
+ def cache_key_to_arg(key: tuple) -> Hashable:
+ return key[self.list_pos]
+
+ cache_keys = [arg_to_cache_key(arg) for arg in list_args]
+ immediate_results, pending_deferred, missing = cache.get_bulk(
+ cache_keys, callback=invalidate_callback
+ )
+
+ results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
+
+ cached_defers: List["defer.Deferred[Any]"] = []
+ if pending_deferred:
+
+ def update_results(r: Dict) -> None:
+ for k, v in r.items():
+ results[cache_key_to_arg(k)] = v
+
+ pending_deferred.addCallback(update_results)
+ cached_defers.append(pending_deferred)
if missing:
- # we need a deferred for each entry in the list,
- # which we put in the cache. Each deferred resolves with the
- # relevant result for that key.
- deferreds_map = {}
- for arg in missing:
- deferred: "defer.Deferred[Any]" = defer.Deferred()
- deferreds_map[arg] = deferred
- key = arg_to_cache_key(arg)
- cached_defers.append(
- cache.set(key, deferred, callback=invalidate_callback)
- )
+ cache_entry = cache.start_bulk_input(missing, invalidate_callback)
def complete_all(res: Dict[Hashable, Any]) -> None:
- # the wrapped function has completed. It returns a dict.
- # We can now update our own result map, and then resolve the
- # observable deferreds in the cache.
- for e, d1 in deferreds_map.items():
- val = res.get(e, None)
- # make sure we update the results map before running the
- # deferreds, because as soon as we run the last deferred, the
- # gatherResults() below will complete and return the result
- # dict to our caller.
- results[e] = val
- d1.callback(val)
+ missing_results = {}
+ for key in missing:
+ arg = cache_key_to_arg(key)
+ val = res.get(arg, None)
+
+ results[arg] = val
+ missing_results[key] = val
+
+ cache_entry.complete_bulk(cache, missing_results)
def errback_all(f: Failure) -> None:
- # the wrapped function has failed. Propagate the failure into
- # the cache, which will invalidate the entry, and cause the
- # relevant cached_deferreds to fail, which will propagate the
- # failure to our caller.
- for d1 in deferreds_map.values():
- d1.errback(f)
+ cache_entry.error_bulk(cache, missing, f)
args_to_call = dict(arg_dict)
- args_to_call[self.list_name] = missing
+ args_to_call[self.list_name] = {
+ cache_key_to_arg(key) for key in missing
+ }
# dispatch the call, and attach the two handlers
- defer.maybeDeferred(
+ missing_d = defer.maybeDeferred(
preserve_fn(self.orig), **args_to_call
).addCallbacks(complete_all, errback_all)
+ cached_defers.append(missing_d)
if cached_defers:
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
@@ -525,7 +519,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
else:
return defer.succeed(results)
- obj.__dict__[self.orig.__name__] = wrapped
+ obj.__dict__[self.name] = wrapped
return wrapped
@@ -577,6 +571,7 @@ def cached(
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
+ name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
@@ -587,13 +582,18 @@ def cached(
cache_context=cache_context,
iterable=iterable,
prune_unread_entries=prune_unread_entries,
+ name=name,
)
return cast(Callable[[F], _CachedFunction[F]], func)
def cachedList(
- *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
+ *,
+ cached_method_name: str,
+ list_name: str,
+ num_args: Optional[int] = None,
+ name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
@@ -628,6 +628,7 @@ def cachedList(
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
+ name=name,
)
return cast(Callable[[F], _CachedFunction[F]], func)
|