summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/_base.py40
-rw-r--r--synapse/storage/events.py1
-rw-r--r--synapse/storage/stream.py3
3 files changed, 30 insertions, 14 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0aab9a8af4..0ada6029fa 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -54,13 +54,12 @@ cache_counter = metrics.register_cache(
 
 
 # TODO(paul):
-#  * more generic key management
 #  * consider other eviction strategies - LRU?
-def cached(max_entries=1000):
+def cached(max_entries=1000, num_args=1):
     """ A method decorator that applies a memoizing cache around the function.
 
-    The function is presumed to take one additional argument, which is used as
-    the key for the cache. Cache hits are served directly from the cache;
+    The function is presumed to take zero or more arguments, which are used in
+    a tuple as the key for the cache. Hits are served directly from the cache;
     misses use the function body to generate the value.
 
     The wrapped function has an additional member, a callable called
@@ -76,26 +75,41 @@ def cached(max_entries=1000):
 
         caches_by_name[name] = cache
 
-        def prefill(key, value):
+        def prefill(*args):  # because I can't  *keyargs, value
+            keyargs = args[:-1]
+            value = args[-1]
+
+            if len(keyargs) != num_args:
+                raise ValueError("Expected a call to have %d arguments", num_args)
+
             while len(cache) > max_entries:
                 cache.popitem(last=False)
 
-            cache[key] = value
+            cache[keyargs] = value
 
         @functools.wraps(orig)
         @defer.inlineCallbacks
-        def wrapped(self, key):
-            if key in cache:
+        def wrapped(self, *keyargs):
+            if len(keyargs) != num_args:
+                raise ValueError("Expected a call to have %d arguments", num_args)
+
+            if keyargs in cache:
                 cache_counter.inc_hits(name)
-                defer.returnValue(cache[key])
+                defer.returnValue(cache[keyargs])
 
             cache_counter.inc_misses(name)
-            ret = yield orig(self, key)
-            prefill(key, ret)
+            ret = yield orig(self, *keyargs)
+
+            prefill_args = keyargs + (ret,)
+            prefill(*prefill_args)
+
             defer.returnValue(ret)
 
-        def invalidate(key):
-            cache.pop(key, None)
+        def invalidate(*keyargs):
+            if len(keyargs) != num_args:
+                raise ValueError("Expected a call to have %d arguments", num_args)
+
+            cache.pop(keyargs, None)
 
         wrapped.invalidate = invalidate
         wrapped.prefill = prefill
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index b295dc5b27..a86230d92c 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -52,6 +52,7 @@ class EventsStore(SQLBaseStore):
                 is_new_state=is_new_state,
                 current_state=current_state,
             )
+            self.get_room_events_max_id.invalidate()
         except _RollbackButIsFineException:
             pass
 
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index df234efdff..66f307e640 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -35,7 +35,7 @@ what sort order was used:
 
 from twisted.internet import defer
 
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.util.logutils import log_function
@@ -413,6 +413,7 @@ class StreamStore(SQLBaseStore):
             "get_recent_events_for_room", get_recent_events_for_room_txn
         )
 
+    @cached(num_args=0)
     def get_room_events_max_id(self):
         return self.runInteraction(
             "get_room_events_max_id",