summary refs log tree commit diff
path: root/synapse/util/caches/descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r--synapse/util/caches/descriptors.py115
1 files changed, 58 insertions, 57 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..10aff4d04a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -25,6 +25,7 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
+    List,
     Mapping,
     Optional,
     Sequence,
@@ -73,8 +74,10 @@ class _CacheDescriptorBase:
         num_args: Optional[int],
         uncached_args: Optional[Collection[str]] = None,
         cache_context: bool = False,
+        name: Optional[str] = None,
     ):
         self.orig = orig
+        self.name = name or orig.__name__
 
         arg_spec = inspect.getfullargspec(orig)
         all_args = arg_spec.args
@@ -211,7 +214,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: LruCache[CacheKey, Any] = LruCache(
-            cache_name=self.orig.__name__,
+            cache_name=self.name,
             max_size=self.max_entries,
         )
 
@@ -241,7 +244,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
         wrapped = cast(_CachedFunction, _wrapped)
         wrapped.cache = cache
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -301,12 +304,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         cache_context: bool = False,
         iterable: bool = False,
         prune_unread_entries: bool = True,
+        name: Optional[str] = None,
     ):
         super().__init__(
             orig,
             num_args=num_args,
             uncached_args=uncached_args,
             cache_context=cache_context,
+            name=name,
         )
 
         if tree and self.num_args < 2:
@@ -321,7 +326,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
-            name=self.orig.__name__,
+            name=self.name,
             max_entries=self.max_entries,
             tree=self.tree,
             iterable=self.iterable,
@@ -372,7 +377,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         wrapped.cache = cache
         wrapped.num_args = self.num_args
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -393,6 +398,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
         cached_method_name: str,
         list_name: str,
         num_args: Optional[int] = None,
+        name: Optional[str] = None,
     ):
         """
         Args:
@@ -403,7 +409,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
         """
-        super().__init__(orig, num_args=num_args, uncached_args=None)
+        super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
 
         self.list_name = list_name
 
@@ -435,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            results = {}
-
-            def update_results_dict(res: Any, arg: Hashable) -> None:
-                results[arg] = res
-
-            # list of deferreds to wait for
-            cached_defers = []
-
-            missing = set()
-
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
@@ -452,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 def arg_to_cache_key(arg: Hashable) -> Hashable:
                     return arg
 
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key
+
             else:
                 keylist = list(keyargs)
 
@@ -459,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     keylist[self.list_pos] = arg
                     return tuple(keylist)
 
-            for arg in list_args:
-                try:
-                    res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
-                    if not res.called:
-                        res.addCallback(update_results_dict, arg)
-                        cached_defers.append(res)
-                    else:
-                        results[arg] = res.result
-                except KeyError:
-                    missing.add(arg)
+                def cache_key_to_arg(key: tuple) -> Hashable:
+                    return key[self.list_pos]
+
+            cache_keys = [arg_to_cache_key(arg) for arg in list_args]
+            immediate_results, pending_deferred, missing = cache.get_bulk(
+                cache_keys, callback=invalidate_callback
+            )
+
+            results = {cache_key_to_arg(key): v for key, v in immediate_results.items()}
+
+            cached_defers: List["defer.Deferred[Any]"] = []
+            if pending_deferred:
+
+                def update_results(r: Dict) -> None:
+                    for k, v in r.items():
+                        results[cache_key_to_arg(k)] = v
+
+                pending_deferred.addCallback(update_results)
+                cached_defers.append(pending_deferred)
 
             if missing:
-                # we need a deferred for each entry in the list,
-                # which we put in the cache. Each deferred resolves with the
-                # relevant result for that key.
-                deferreds_map = {}
-                for arg in missing:
-                    deferred: "defer.Deferred[Any]" = defer.Deferred()
-                    deferreds_map[arg] = deferred
-                    key = arg_to_cache_key(arg)
-                    cached_defers.append(
-                        cache.set(key, deferred, callback=invalidate_callback)
-                    )
+                cache_entry = cache.start_bulk_input(missing, invalidate_callback)
 
                 def complete_all(res: Dict[Hashable, Any]) -> None:
-                    # the wrapped function has completed. It returns a dict.
-                    # We can now update our own result map, and then resolve the
-                    # observable deferreds in the cache.
-                    for e, d1 in deferreds_map.items():
-                        val = res.get(e, None)
-                        # make sure we update the results map before running the
-                        # deferreds, because as soon as we run the last deferred, the
-                        # gatherResults() below will complete and return the result
-                        # dict to our caller.
-                        results[e] = val
-                        d1.callback(val)
+                    missing_results = {}
+                    for key in missing:
+                        arg = cache_key_to_arg(key)
+                        val = res.get(arg, None)
+
+                        results[arg] = val
+                        missing_results[key] = val
+
+                    cache_entry.complete_bulk(cache, missing_results)
 
                 def errback_all(f: Failure) -> None:
-                    # the wrapped function has failed. Propagate the failure into
-                    # the cache, which will invalidate the entry, and cause the
-                    # relevant cached_deferreds to fail, which will propagate the
-                    # failure to our caller.
-                    for d1 in deferreds_map.values():
-                        d1.errback(f)
+                    cache_entry.error_bulk(cache, missing, f)
 
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = {
+                    cache_key_to_arg(key) for key in missing
+                }
 
                 # dispatch the call, and attach the two handlers
-                defer.maybeDeferred(
+                missing_d = defer.maybeDeferred(
                     preserve_fn(self.orig), **args_to_call
                 ).addCallbacks(complete_all, errback_all)
+                cached_defers.append(missing_d)
 
             if cached_defers:
                 d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
@@ -525,7 +519,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             else:
                 return defer.succeed(results)
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -577,6 +571,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
     prune_unread_entries: bool = True,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     func = lambda orig: DeferredCacheDescriptor(
         orig,
@@ -587,13 +582,18 @@ def cached(
         cache_context=cache_context,
         iterable=iterable,
         prune_unread_entries=prune_unread_entries,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
 
 
 def cachedList(
-    *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
+    *,
+    cached_method_name: str,
+    list_name: str,
+    num_args: Optional[int] = None,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
 
@@ -628,6 +628,7 @@ def cachedList(
         cached_method_name=cached_method_name,
         list_name=list_name,
         num_args=num_args,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)