diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 8f812f0fd7..f1265541ba 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -27,6 +27,7 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict
import functools
+import inspect
import sys
import time
import threading
@@ -141,13 +142,28 @@ class CacheDescriptor(object):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
- def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+ def __init__(self, orig, max_entries=1000, num_args=1, lru=False,
+ inlineCallbacks=False):
self.orig = orig
+ if inlineCallbacks:
+ self.function_to_call = defer.inlineCallbacks(orig)
+ else:
+ self.function_to_call = orig
+
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
+ self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+ if len(self.arg_names) < self.num_args:
+ raise Exception(
+ "Not enough explicit positional arguments to key off of for %r."
+ " (@cached cannot key off of *args or **kwars)"
+ % (orig.__name__,)
+ )
+
def __get__(self, obj, objtype=None):
cache = Cache(
name=self.orig.__name__,
@@ -158,11 +174,13 @@ class CacheDescriptor(object):
@functools.wraps(self.orig)
@defer.inlineCallbacks
- def wrapped(*keyargs):
+ def wrapped(*args, **kwargs):
+ arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+ keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
try:
- cached_result = cache.get(*keyargs[:self.num_args])
+ cached_result = cache.get(*keyargs)
if DEBUG_CACHES:
- actual_result = yield self.orig(obj, *keyargs)
+ actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
@@ -177,9 +195,9 @@ class CacheDescriptor(object):
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
- ret = yield self.orig(obj, *keyargs)
+ ret = yield self.function_to_call(obj, *args, **kwargs)
- cache.update(sequence, *keyargs[:self.num_args] + (ret,))
+ cache.update(sequence, *(keyargs + [ret]))
defer.returnValue(ret)
@@ -201,6 +219,16 @@ def cached(max_entries=1000, num_args=1, lru=False):
)
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru,
+ inlineCallbacks=True,
+ )
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
|