diff options
Diffstat (limited to 'synapse/util/caches/descriptors.py')
-rw-r--r-- | synapse/util/caches/descriptors.py | 89 |
1 files changed, 39 insertions, 50 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 9d4bc89edb..10aff4d04a 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -25,6 +25,7 @@ from typing import ( Generic, Hashable, Iterable, + List, Mapping, Optional, Sequence, @@ -440,16 +441,6 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] list_args = arg_dict[self.list_name] - results = {} - - def update_results_dict(res: Any, arg: Hashable) -> None: - results[arg] = res - - # list of deferreds to wait for - cached_defers = [] - - missing = set() - # If the cache takes a single arg then that is used as the key, # otherwise a tuple is used. if num_args == 1: @@ -457,6 +448,9 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): def arg_to_cache_key(arg: Hashable) -> Hashable: return arg + def cache_key_to_arg(key: tuple) -> Hashable: + return key + else: keylist = list(keyargs) @@ -464,58 +458,53 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): keylist[self.list_pos] = arg return tuple(keylist) - for arg in list_args: - try: - res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) - if not res.called: - res.addCallback(update_results_dict, arg) - cached_defers.append(res) - else: - results[arg] = res.result - except KeyError: - missing.add(arg) + def cache_key_to_arg(key: tuple) -> Hashable: + return key[self.list_pos] + + cache_keys = [arg_to_cache_key(arg) for arg in list_args] + immediate_results, pending_deferred, missing = cache.get_bulk( + cache_keys, callback=invalidate_callback + ) + + results = {cache_key_to_arg(key): v for key, v in immediate_results.items()} + + cached_defers: List["defer.Deferred[Any]"] = [] + if pending_deferred: + + def update_results(r: Dict) -> None: + for k, v in r.items(): + results[cache_key_to_arg(k)] = v + + pending_deferred.addCallback(update_results) + cached_defers.append(pending_deferred) if missing: - # we need a deferred for each entry in the list, - # which we put in the cache. Each deferred resolves with the - # relevant result for that key. - deferreds_map = {} - for arg in missing: - deferred: "defer.Deferred[Any]" = defer.Deferred() - deferreds_map[arg] = deferred - key = arg_to_cache_key(arg) - cached_defers.append( - cache.set(key, deferred, callback=invalidate_callback) - ) + cache_entry = cache.start_bulk_input(missing, invalidate_callback) def complete_all(res: Dict[Hashable, Any]) -> None: - # the wrapped function has completed. It returns a dict. - # We can now update our own result map, and then resolve the - # observable deferreds in the cache. - for e, d1 in deferreds_map.items(): - val = res.get(e, None) - # make sure we update the results map before running the - # deferreds, because as soon as we run the last deferred, the - # gatherResults() below will complete and return the result - # dict to our caller. - results[e] = val - d1.callback(val) + missing_results = {} + for key in missing: + arg = cache_key_to_arg(key) + val = res.get(arg, None) + + results[arg] = val + missing_results[key] = val + + cache_entry.complete_bulk(cache, missing_results) def errback_all(f: Failure) -> None: - # the wrapped function has failed. Propagate the failure into - # the cache, which will invalidate the entry, and cause the - # relevant cached_deferreds to fail, which will propagate the - # failure to our caller. - for d1 in deferreds_map.values(): - d1.errback(f) + cache_entry.error_bulk(cache, missing, f) args_to_call = dict(arg_dict) - args_to_call[self.list_name] = missing + args_to_call[self.list_name] = { + cache_key_to_arg(key) for key in missing + } # dispatch the call, and attach the two handlers - defer.maybeDeferred( + missing_d = defer.maybeDeferred( preserve_fn(self.orig), **args_to_call ).addCallbacks(complete_all, errback_all) + cached_defers.append(missing_d) if cached_defers: d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( |