summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-07-27 13:57:29 +0100
committerErik Johnston <erik@matrix.org>2015-07-27 13:57:29 +0100
commit39e21ea51cf94f436f1f845f24c1b04f11b92c6f (patch)
treeb9517d390414bf0512c087c58ff061fd42b33b9c /synapse/storage
parentMerge pull request #205 from matrix-org/erikj/pick_largest_thumbnail (diff)
downloadsynapse-39e21ea51cf94f436f1f845f24c1b04f11b92c6f.tar.xz
Add support for using keyword arguments with cached functions
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py40
-rw-r--r--synapse/storage/keys.py5
-rw-r--r--synapse/storage/push_rule.py8
-rw-r--r--synapse/storage/receipts.py5
-rw-r--r--synapse/storage/room.py5
-rw-r--r--synapse/storage/state.py5
6 files changed, 45 insertions, 23 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 8f812f0fd7..f1265541ba 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -27,6 +27,7 @@ from twisted.internet import defer
 from collections import namedtuple, OrderedDict
 
 import functools
+import inspect
 import sys
 import time
 import threading
@@ -141,13 +142,28 @@ class CacheDescriptor(object):
     which can be used to insert values into the cache specifically, without
     calling the calculation function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
+    def __init__(self, orig, max_entries=1000, num_args=1, lru=False,
+                 inlineCallbacks=False):
         self.orig = orig
 
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
         self.max_entries = max_entries
         self.num_args = num_args
         self.lru = lru
 
+        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+        if len(self.arg_names) < self.num_args:
+            raise Exception(
+                "Not enough explicit positional arguments to key off of for %r."
+                " (@cached cannot key off of *args or **kwars)"
+                % (orig.__name__,)
+            )
+
     def __get__(self, obj, objtype=None):
         cache = Cache(
             name=self.orig.__name__,
@@ -158,11 +174,13 @@ class CacheDescriptor(object):
 
         @functools.wraps(self.orig)
         @defer.inlineCallbacks
-        def wrapped(*keyargs):
+        def wrapped(*args, **kwargs):
+            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             try:
-                cached_result = cache.get(*keyargs[:self.num_args])
+                cached_result = cache.get(*keyargs)
                 if DEBUG_CACHES:
-                    actual_result = yield self.orig(obj, *keyargs)
+                    actual_result = yield self.function_to_call(obj, *args, **kwargs)
                     if actual_result != cached_result:
                         logger.error(
                             "Stale cache entry %s%r: cached: %r, actual %r",
@@ -177,9 +195,9 @@ class CacheDescriptor(object):
                 # while the SELECT is executing (SYN-369)
                 sequence = cache.sequence
 
-                ret = yield self.orig(obj, *keyargs)
+                ret = yield self.function_to_call(obj, *args, **kwargs)
 
-                cache.update(sequence, *keyargs[:self.num_args] + (ret,))
+                cache.update(sequence, *(keyargs + [ret]))
 
                 defer.returnValue(ret)
 
@@ -201,6 +219,16 @@ def cached(max_entries=1000, num_args=1, lru=False):
     )
 
 
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+    return lambda orig: CacheDescriptor(
+        orig,
+        max_entries=max_entries,
+        num_args=num_args,
+        lru=lru,
+        inlineCallbacks=True,
+    )
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 940a5f7e08..e3f98f0cde 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from _base import SQLBaseStore, cached
+from _base import SQLBaseStore, cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore):
             desc="store_server_certificate",
         )
 
-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_all_server_verify_keys(self, server_name):
         rows = yield self._simple_select_list(
             table="server_signature_keys",
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 4cac118d17..a220f3632e 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
 from twisted.internet import defer
 
 import logging
@@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
 
 
 class PushRuleStore(SQLBaseStore):
-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_push_rules_for_user(self, user_name):
         rows = yield self._simple_select_list(
             table=PushRuleTable.table_name,
@@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(rows)
 
-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_push_rules_enabled_for_user(self, user_name):
         results = yield self._simple_select_list(
             table=PushRuleEnableTable.table_name,
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 7a6af98d98..b79d6683ca 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore):
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_max_token(self)
 
-    @cached
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_graph_receipts_for_room(self, room_id):
         """Get receipts for sending to remote servers.
         """
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 4612a8aa83..dd5bc2c8fb 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cachedInlineCallbacks
 
 import collections
 import logging
@@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
                 }
             )
 
-    @cached()
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks()
     def get_room_name_and_aliases(self, room_id):
         def f(txn):
             sql = (
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 47bec65497..55c6d52890 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore, cached
+from ._base import SQLBaseStore, cached, cachedInlineCallbacks
 
 from twisted.internet import defer
 
@@ -189,8 +189,7 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
-    @cached(num_args=3)
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks(num_args=3)
     def get_current_state_for_key(self, room_id, event_type, state_key):
         def f(txn):
             sql = (