1 files changed, 28 insertions, 6 deletions
| diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1607978e29..eed60d567e 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -197,6 +197,7 @@ class _CacheDescriptorBase(object):
 
         arg_spec = inspect.getargspec(orig)
         all_args = arg_spec.args
+        self.arg_spec = arg_spec
 
         if "cache_context" in all_args:
             if not cache_context:
@@ -226,6 +227,14 @@ class _CacheDescriptorBase(object):
         self.num_args = num_args
         self.arg_names = all_args[1:num_args + 1]
 
+        if arg_spec.defaults:
+            self.arg_defaults = dict(zip(
+                all_args[-len(arg_spec.defaults):],
+                arg_spec.defaults
+            ))
+        else:
+            self.arg_defaults = {}
+
         if "cache_context" in self.arg_names:
             raise Exception(
                 "cache_context arg cannot be included among the cache keys"
@@ -289,18 +298,31 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )
 
+        def get_cache_key(args, kwargs):
+            """Given some args/kwargs return a generator that resolves into
+            the cache_key.
+
+            We loop through each arg name, looking up if its in the `kwargs`,
+            otherwise using the next argument in `args`. If there are no more
+            args then we try looking the arg name up in the defaults
+            """
+            pos = 0
+            for nm in self.arg_names:
+                if nm in kwargs:
+                    yield kwargs[nm]
+                elif pos < len(args):
+                    yield args[pos]
+                    pos += 1
+                else:
+                    yield self.arg_defaults[nm]
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add temp cache_context so inspect.getcallargs doesn't explode
-            if self.add_cache_context:
-                kwargs["cache_context"] = None
-
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            cache_key = tuple(get_cache_key(args, kwargs))
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
 |