summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py335
1 files changed, 3 insertions, 332 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e5441aafb2..1444767a52 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,27 +15,22 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
-from synapse.util.lrucache import LruCache
-from synapse.util.dictionary_cache import DictionaryCache
+from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.caches.descriptors import Cache
 import synapse.metrics
 
 from util.id_generators import IdGenerator, StreamIdGenerator
 
 from twisted.internet import defer
 
-from collections import namedtuple, OrderedDict
+from collections import namedtuple
 
-import functools
-import inspect
 import sys
 import time
 import threading
 
-DEBUG_CACHES = False
 
 logger = logging.getLogger(__name__)
 
@@ -51,330 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
 sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
 sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
 
-caches_by_name = {}
-cache_counter = metrics.register_cache(
-    "cache",
-    lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
-    labels=["name"],
-)
-
-
-_CacheSentinel = object()
-
-
-class Cache(object):
-
-    def __init__(self, name, max_entries=1000, keylen=1, lru=True):
-        if lru:
-            self.cache = LruCache(max_size=max_entries)
-            self.max_entries = None
-        else:
-            self.cache = OrderedDict()
-            self.max_entries = max_entries
-
-        self.name = name
-        self.keylen = keylen
-        self.sequence = 0
-        self.thread = None
-        caches_by_name[name] = self.cache
-
-    def check_thread(self):
-        expected_thread = self.thread
-        if expected_thread is None:
-            self.thread = threading.current_thread()
-        else:
-            if expected_thread is not threading.current_thread():
-                raise ValueError(
-                    "Cache objects can only be accessed from the main thread"
-                )
-
-    def get(self, key, default=_CacheSentinel):
-        val = self.cache.get(key, _CacheSentinel)
-        if val is not _CacheSentinel:
-            cache_counter.inc_hits(self.name)
-            return val
-
-        cache_counter.inc_misses(self.name)
-
-        if default is _CacheSentinel:
-            raise KeyError()
-        else:
-            return default
-
-    def update(self, sequence, key, value):
-        self.check_thread()
-        if self.sequence == sequence:
-            # Only update the cache if the caches sequence number matches the
-            # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(key, value)
-
-    def prefill(self, key, value):
-        if self.max_entries is not None:
-            while len(self.cache) >= self.max_entries:
-                self.cache.popitem(last=False)
-
-        self.cache[key] = value
-
-    def invalidate(self, key):
-        self.check_thread()
-        if not isinstance(key, tuple):
-            raise ValueError("keyargs must be a tuple.")
-
-        # Increment the sequence number so that any SELECT statements that
-        # raced with the INSERT don't update the cache (SYN-369)
-        self.sequence += 1
-        self.cache.pop(key, None)
-
-    def invalidate_all(self):
-        self.check_thread()
-        self.sequence += 1
-        self.cache.clear()
-
-
-class CacheDescriptor(object):
-    """ A method decorator that applies a memoizing cache around the function.
-
-    This caches deferreds, rather than the results themselves. Deferreds that
-    fail are removed from the cache.
-
-    The function is presumed to take zero or more arguments, which are used in
-    a tuple as the key for the cache. Hits are served directly from the cache;
-    misses use the function body to generate the value.
-
-    The wrapped function has an additional member, a callable called
-    "invalidate". This can be used to remove individual entries from the cache.
-
-    The wrapped function has another additional callable, called "prefill",
-    which can be used to insert values into the cache specifically, without
-    calling the calculation function.
-    """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
-                 inlineCallbacks=False):
-        self.orig = orig
-
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
-        self.max_entries = max_entries
-        self.num_args = num_args
-        self.lru = lru
-
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
-
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
-                % (orig.__name__,)
-            )
-
-        self.cache = Cache(
-            name=self.orig.__name__,
-            max_entries=self.max_entries,
-            keylen=self.num_args,
-            lru=self.lru,
-        )
-
-    def __get__(self, obj, objtype=None):
-
-        @functools.wraps(self.orig)
-        def wrapped(*args, **kwargs):
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
-            try:
-                cached_result_d = self.cache.get(cache_key)
-
-                observer = cached_result_d.observe()
-                if DEBUG_CACHES:
-                    @defer.inlineCallbacks
-                    def check_result(cached_result):
-                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
-                        if actual_result != cached_result:
-                            logger.error(
-                                "Stale cache entry %s%r: cached: %r, actual %r",
-                                self.orig.__name__, cache_key,
-                                cached_result, actual_result,
-                            )
-                            raise ValueError("Stale cache entry")
-                        defer.returnValue(cached_result)
-                    observer.addCallback(check_result)
-
-                return observer
-            except KeyError:
-                # Get the sequence number of the cache before reading from the
-                # database so that we can tell if the cache is invalidated
-                # while the SELECT is executing (SYN-369)
-                sequence = self.cache.sequence
-
-                ret = defer.maybeDeferred(
-                    self.function_to_call,
-                    obj, *args, **kwargs
-                )
-
-                def onErr(f):
-                    self.cache.invalidate(cache_key)
-                    return f
-
-                ret.addErrback(onErr)
-
-                ret = ObservableDeferred(ret, consumeErrors=True)
-                self.cache.update(sequence, cache_key, ret)
-
-                return ret.observe()
-
-        wrapped.invalidate = self.cache.invalidate
-        wrapped.invalidate_all = self.cache.invalidate_all
-        wrapped.prefill = self.cache.prefill
-
-        obj.__dict__[self.orig.__name__] = wrapped
-
-        return wrapped
-
-
-class CacheListDescriptor(object):
-    """Wraps an existing cache to support bulk fetching of keys.
-
-    Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped fucntion.
-    """
-
-    def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
-        """
-        Args:
-            orig (function)
-            cache (Cache)
-            list_name (str): Name of the argument which is the bulk lookup list
-            num_args (int)
-            inlineCallbacks (bool): Whether orig is a generator that should
-                be wrapped by defer.inlineCallbacks
-        """
-        self.orig = orig
-
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
-        self.num_args = num_args
-        self.list_name = list_name
-
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
-        self.list_pos = self.arg_names.index(self.list_name)
-
-        self.cache = cache
-
-        self.sentinel = object()
-
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
-                % (orig.__name__,)
-            )
-
-        if self.list_name not in self.arg_names:
-            raise Exception(
-                "Couldn't see arguments %r for %r."
-                % (self.list_name, cache.name,)
-            )
-
-    def __get__(self, obj, objtype=None):
-
-        @functools.wraps(self.orig)
-        def wrapped(*args, **kwargs):
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
-            list_args = arg_dict[self.list_name]
-
-            # cached is a dict arg -> deferred, where deferred results in a
-            # 2-tuple (`arg`, `result`)
-            cached = {}
-            missing = []
-            for arg in list_args:
-                key = list(keyargs)
-                key[self.list_pos] = arg
-
-                try:
-                    res = self.cache.get(tuple(key)).observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-                    cached[arg] = res
-                except KeyError:
-                    missing.append(arg)
-
-            if missing:
-                sequence = self.cache.sequence
-                args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
-
-                ret_d = defer.maybeDeferred(
-                    self.function_to_call,
-                    **args_to_call
-                )
-
-                ret_d = ObservableDeferred(ret_d)
-
-                # We need to create deferreds for each arg in the list so that
-                # we can insert the new deferred into the cache.
-                for arg in missing:
-                    observer = ret_d.observe()
-                    observer.addCallback(lambda r, arg: r[arg], arg)
-
-                    observer = ObservableDeferred(observer)
-
-                    key = list(keyargs)
-                    key[self.list_pos] = arg
-                    self.cache.update(sequence, tuple(key), observer)
-
-                    def invalidate(f, key):
-                        self.cache.invalidate(key)
-                        return f
-                    observer.addErrback(invalidate, tuple(key))
-
-                    res = observer.observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-
-                    cached[arg] = res
-
-            return defer.gatherResults(
-                cached.values(),
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
-
-        obj.__dict__[self.orig.__name__] = wrapped
-
-        return wrapped
-
-
-def cached(max_entries=1000, num_args=1, lru=True):
-    return lambda orig: CacheDescriptor(
-        orig,
-        max_entries=max_entries,
-        num_args=num_args,
-        lru=lru
-    )
-
-
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
-    return lambda orig: CacheDescriptor(
-        orig,
-        max_entries=max_entries,
-        num_args=num_args,
-        lru=lru,
-        inlineCallbacks=True,
-    )
-
-
-def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
-    return lambda orig: CacheListDescriptor(
-        orig,
-        cache=cache,
-        list_name=list_name,
-        num_args=num_args,
-        inlineCallbacks=inlineCallbacks,
-    )
-
 
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object