diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 39884c2afe..73eea157a4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,6 +15,7 @@
import logging
from synapse.api.errors import StoreError
+from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
@@ -27,6 +28,7 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict
import functools
+import inspect
import sys
import time
import threading
@@ -55,9 +57,12 @@ cache_counter = metrics.register_cache(
)
+_CacheSentinel = object()
+
+
class Cache(object):
- def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+ def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
@@ -81,45 +86,44 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
- def get(self, *keyargs):
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
-
- if keyargs in self.cache:
+ def get(self, key, default=_CacheSentinel):
+ val = self.cache.get(key, _CacheSentinel)
+ if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
- return self.cache[keyargs]
+ return val
cache_counter.inc_misses(self.name)
- raise KeyError()
- def update(self, sequence, *args):
+ if default is _CacheSentinel:
+ raise KeyError()
+ else:
+ return default
+
+ def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
- self.prefill(*args)
-
- def prefill(self, *args): # because I can't *keyargs, value
- keyargs = args[:-1]
- value = args[-1]
-
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ self.prefill(key, value)
+ def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
- self.cache[keyargs] = value
+ self.cache[key] = value
- def invalidate(self, *keyargs):
+ def invalidate(self, key):
self.check_thread()
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
+ if not isinstance(key, tuple):
+ raise TypeError(
+ "The cache key must be a tuple not %r" % (type(key),)
+ )
+
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
- self.cache.pop(keyargs, None)
+ self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
@@ -127,9 +131,12 @@ class Cache(object):
self.cache.clear()
-def cached(max_entries=1000, num_args=1, lru=False):
+class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
+ This caches deferreds, rather than the results themselves. Deferreds that
+ fail are removed from the cache.
+
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
@@ -141,47 +148,108 @@ def cached(max_entries=1000, num_args=1, lru=False):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
- def wrap(orig):
- cache = Cache(
- name=orig.__name__,
- max_entries=max_entries,
- keylen=num_args,
- lru=lru,
+ def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+ inlineCallbacks=False):
+ self.orig = orig
+
+ if inlineCallbacks:
+ self.function_to_call = defer.inlineCallbacks(orig)
+ else:
+ self.function_to_call = orig
+
+ self.max_entries = max_entries
+ self.num_args = num_args
+ self.lru = lru
+
+ self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+
+ if len(self.arg_names) < self.num_args:
+ raise Exception(
+ "Not enough explicit positional arguments to key off of for %r."
+ " (@cached cannot key off of *args or **kwars)"
+ % (orig.__name__,)
+ )
+
+ self.cache = Cache(
+ name=self.orig.__name__,
+ max_entries=self.max_entries,
+ keylen=self.num_args,
+ lru=self.lru,
)
- @functools.wraps(orig)
- @defer.inlineCallbacks
- def wrapped(self, *keyargs):
+ def __get__(self, obj, objtype=None):
+
+ @functools.wraps(self.orig)
+ def wrapped(*args, **kwargs):
+ arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
+ cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
- cached_result = cache.get(*keyargs)
+ cached_result_d = self.cache.get(cache_key)
+
+ observer = cached_result_d.observe()
if DEBUG_CACHES:
- actual_result = yield orig(self, *keyargs)
- if actual_result != cached_result:
- logger.error(
- "Stale cache entry %s%r: cached: %r, actual %r",
- orig.__name__, keyargs,
- cached_result, actual_result,
- )
- raise ValueError("Stale cache entry")
- defer.returnValue(cached_result)
+ @defer.inlineCallbacks
+ def check_result(cached_result):
+ actual_result = yield self.function_to_call(obj, *args, **kwargs)
+ if actual_result != cached_result:
+ logger.error(
+ "Stale cache entry %s%r: cached: %r, actual %r",
+ self.orig.__name__, cache_key,
+ cached_result, actual_result,
+ )
+ raise ValueError("Stale cache entry")
+ defer.returnValue(cached_result)
+ observer.addCallback(check_result)
+
+ return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
- sequence = cache.sequence
+ sequence = self.cache.sequence
+
+ ret = defer.maybeDeferred(
+ self.function_to_call,
+ obj, *args, **kwargs
+ )
+
+ def onErr(f):
+ self.cache.invalidate(cache_key)
+ return f
+
+ ret.addErrback(onErr)
+
+ ret = ObservableDeferred(ret, consumeErrors=True)
+ self.cache.update(sequence, cache_key, ret)
- ret = yield orig(self, *keyargs)
+ return ret.observe()
- cache.update(sequence, *keyargs + (ret,))
+ wrapped.invalidate = self.cache.invalidate
+ wrapped.invalidate_all = self.cache.invalidate_all
+ wrapped.prefill = self.cache.prefill
- defer.returnValue(ret)
+ obj.__dict__[self.orig.__name__] = wrapped
- wrapped.invalidate = cache.invalidate
- wrapped.invalidate_all = cache.invalidate_all
- wrapped.prefill = cache.prefill
return wrapped
- return wrap
+
+def cached(max_entries=1000, num_args=1, lru=True):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru
+ )
+
+
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+ return lambda orig: CacheDescriptor(
+ orig,
+ max_entries=max_entries,
+ num_args=num_args,
+ lru=lru,
+ inlineCallbacks=True,
+ )
class LoggingTransaction(object):
@@ -312,13 +380,14 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
- self._stream_id_gen = StreamIdGenerator()
+ self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+ self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
|