diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f8a07df6b8..861c24809c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
-            # If we're passed a cache_context then we'll want to call its invalidate()
-            # whenever we are invalidated
+            # If we're passed a cache_context then we'll want to call its
+            # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
-            # cached is a dict arg -> deferred, where deferred results in a
-            # 2-tuple (`arg`, `result`)
             results = {}
-            cached_defers = {}
-            missing = []
+
+            def update_results_dict(res, arg):
+                results[arg] = res
+
+            # list of deferreds to wait for
+            cached_defers = []
+
+            missing = set()
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
-                def cache_get(arg):
-                    return cache.get(arg, callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    return arg
             else:
-                key = list(keyargs)
+                keylist = list(keyargs)
-                def cache_get(arg):
-                    key[self.list_pos] = arg
-                    return cache.get(tuple(key), callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    keylist[self.list_pos] = arg
+                    return tuple(keylist)
             for arg in list_args:
                 try:
-                    res = cache_get(arg)
-
+                    res = cache.get(arg_to_cache_key(arg),
+                                    callback=invalidate_callback)
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
                         res = res.observe()
-                        res.addCallback(lambda r, arg: (arg, r), arg)
-                        cached_defers[arg] = res
+                        res.addCallback(update_results_dict, arg)
+                        cached_defers.append(res)
                     else:
                         results[arg] = res.get_result()
                 except KeyError:
-                    missing.append(arg)
+                    missing.add(arg)
             if missing:
+                # we need an observable deferred for each entry in the list,
+                # which we put in the cache. Each deferred resolves with the
+                # relevant result for that key.
+                deferreds_map = {}
+                for arg in missing:
+                    deferred = defer.Deferred()
+                    deferreds_map[arg] = deferred
+                    key = arg_to_cache_key(arg)
+                    observable = ObservableDeferred(deferred)
+                    cache.set(key, observable, callback=invalidate_callback)
+
+                def complete_all(res):
+                    # the wrapped function has completed. It returns a dict.
+                    # We can now resolve the observable deferreds in the
+                    # cache and update our own result map.
+                    for e in missing:
+                        val = res.get(e, None)
+                        deferreds_map[e].callback(val)
+                        results[e] = val
+
+                def errback(f):
+                    # the wrapped function has failed. Invalidate any cache
+                    # entries we're supposed to be populating, and fail
+                    # their deferreds.
+                    for e in missing:
+                        key = arg_to_cache_key(e)
+                        cache.invalidate(key)
+                        deferreds_map[e].errback(f)
+
+                    # return the failure, to propagate to our caller.
+                    return f
+
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = list(missing)
-                ret_d = defer.maybeDeferred(
+                cached_defers.append(defer.maybeDeferred(
                     logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
-                )
-
-                ret_d = ObservableDeferred(ret_d)
-
-                # We need to create deferreds for each arg in the list so that
-                # we can insert the new deferred into the cache.
-                for arg in missing:
-                    observer = ret_d.observe()
-                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)
-
-                    observer = ObservableDeferred(observer)
-
-                    if num_args == 1:
-                        cache.set(
-                            arg, observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, arg)
-                    else:
-                        key = list(keyargs)
-                        key[self.list_pos] = arg
-                        cache.set(
-                            tuple(key), observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, tuple(key))
-
-                    res = observer.observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-
-                    cached_defers[arg] = res
+                ).addCallbacks(complete_all, errback))
             if cached_defers:
-                def update_results_dict(res):
-                    results.update(res)
-                    return results
-
-                return logcontext.make_deferred_yieldable(defer.gatherResults(
-                    list(cached_defers.values()),
+                d = defer.gatherResults(
+                    cached_defers,
                     consumeErrors=True,
-                ).addCallback(update_results_dict).addErrback(
+                ).addCallbacks(
+                    lambda _: results,
                     unwrapFirstError
-                ))
+                )
+                return logcontext.make_deferred_yieldable(d)
             else:
                 return results
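
Review note: the heart of this rewrite is that cache entries for the missing keys are now created *before* the batch function runs, one ObservableDeferred per key, so concurrent lookups of the same key share the in-flight result instead of starting a second fetch. A minimal sketch of that pattern outside the descriptor machinery (assumptions: a plain dict stands in for the real Cache, fetch_batch is a hypothetical batch function returning a Deferred of a dict, and ObservableDeferred is imported from its 2018-era location):

    from twisted.internet import defer
    from synapse.util.async import ObservableDeferred

    pending = {}  # stand-in for the real Cache: key -> ObservableDeferred

    def fetch_missing(missing, fetch_batch):
        # one deferred per missing key, inserted into the cache *before*
        # the batch call, so concurrent callers can observe it
        deferreds_map = {key: defer.Deferred() for key in missing}
        for key, d in deferreds_map.items():
            pending[key] = ObservableDeferred(d)

        def complete_all(res):
            # resolve each per-key deferred with its slice of the result dict
            for key in missing:
                deferreds_map[key].callback(res.get(key))

        def errback(f):
            # drop the entries we were populating and propagate the failure
            for key in missing:
                del pending[key]
                deferreds_map[key].errback(f)
            return f

        return defer.maybeDeferred(fetch_batch, missing).addCallbacks(
            complete_all, errback)

Either way, callers of the wrapped method see the same contract as before: a plain dict if every key was already cached with a resolved value, otherwise a deferred that fires with the completed dict.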
@@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
         cache.
     Args:
-        cache (Cache): The underlying cache to use.
+        cached_method_name (str): The name of the single-item lookup method.
+            This is only used to find the cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
         num_args (int): Number of arguments to use as the key in the cache
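
Review note: with the renamed argument, the decorator is wired up to its single-item counterpart roughly like this (a hedged sketch, not a quote of the docstring; the method bodies are elided):

    from synapse.util.caches.descriptors import cached, cachedList

    class Example(object):
        @cached(num_args=2)
        def do_something(self, first_arg, second_arg):
            pass  # single-item lookup, returning a Deferred

        # cached_method_name names the @cached method above; it is only
        # used to find the cache that batch_do_something should populate
        @cachedList(cached_method_name="do_something",
                    list_name="second_args", num_args=2)
        def batch_do_something(self, first_arg, second_args):
            pass  # batch lookup, returning a Deferred of a dict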
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 581c6052ac..014edea971 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from six import string_types
+from six import binary_type, text_type
 from canonicaljson import json
 from frozendict import frozendict
@@ -26,7 +26,7 @@ def freeze(o):
     if isinstance(o, frozendict):
         return o
-    if isinstance(o, string_types):
+    if isinstance(o, (binary_type, text_type)):
         return o
     try:
@@ -41,7 +41,7 @@ def unfreeze(o):
     if isinstance(o, (dict, frozendict)):
         return dict({k: unfreeze(v) for k, v in o.items()})
-    if isinstance(o, string_types):
+    if isinstance(o, (binary_type, text_type)):
         return o
     try:
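
Review note: the substitution matters on Python 3, where six.string_types is just (str,), so byte strings failed the old check and fell through to the generic sequence-freezing path, which iterates them element by element. A quick illustration of the two checks (Python 3 semantics noted in the comments):

    from six import binary_type, string_types, text_type

    # On Python 3: string_types == (str,); binary_type is bytes; text_type is str
    isinstance(b"raw", string_types)               # False -- old check missed bytes
    isinstance(b"raw", (binary_type, text_type))   # True  -- new check passes them through
    isinstance(u"text", (binary_type, text_type))  # True on both Python 2 and 3

With the old check, freeze(b"abc") on Python 3 would have hit the generic-sequence branch and come back as a tuple of integers; with the new one, bytes and text are both returned unchanged.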