summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/_base.py40
1 files changed, 27 insertions, 13 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 9125bb1198..f483bd1520 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -54,13 +54,12 @@ cache_counter = metrics.register_cache(
 
 
 # TODO(paul):
-#  * more generic key management
 #  * consider other eviction strategies - LRU?
-def cached(max_entries=1000):
+def cached(max_entries=1000, num_args=1):
     """ A method decorator that applies a memoizing cache around the function.
 
-    The function is presumed to take one additional argument, which is used as
-    the key for the cache. Cache hits are served directly from the cache;
+    The function is presumed to take zero or more arguments, which are used in
+    a tuple as the key for the cache. Hits are served directly from the cache;
     misses use the function body to generate the value.
 
     The wrapped function has an additional member, a callable called
@@ -76,26 +75,41 @@ def cached(max_entries=1000):
 
         caches_by_name[name] = cache
 
-        def prefill(key, value):
+        def prefill(*args):  # because I can't  *keyargs, value
+            keyargs = args[:-1]
+            value = args[-1]
+
+            if len(keyargs) != num_args:
+                raise ValueError("Expected a call to have %d arguments", num_args)
+
             while len(cache) > max_entries:
                 cache.popitem(last=False)
 
-            cache[key] = value
+            cache[keyargs] = value
 
         @functools.wraps(orig)
         @defer.inlineCallbacks
-        def wrapped(self, key):
-            if key in cache:
+        def wrapped(self, *keyargs):
+            if len(keyargs) != num_args:
+                raise ValueError("Expected a call to have %d arguments", num_args)
+
+            if keyargs in cache:
                 cache_counter.inc_hits(name)
-                defer.returnValue(cache[key])
+                defer.returnValue(cache[keyargs])
 
             cache_counter.inc_misses(name)
-            ret = yield orig(self, key)
-            prefill(key, ret)
+            ret = yield orig(self, *keyargs)
+
+            prefill_args = keyargs + (ret,)
+            prefill(*prefill_args)
+
             defer.returnValue(ret)
 
-        def invalidate(key):
-            cache.pop(key, None)
+        def invalidate(*keyargs):
+            if len(keyargs) != num_args:
+                raise ValueError("Expected a call to have %d arguments", num_args)
+
+            cache.pop(keyargs, None)
 
         wrapped.invalidate = invalidate
         wrapped.prefill = prefill