Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--  synapse/storage/_base.py | 165
1 file changed, 7 insertions(+), 158 deletions(-)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 8d33def6c6..d976e17786 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,21 +17,20 @@ import logging
 
 from synapse.api.errors import StoreError
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
-from synapse.util.lrucache import LruCache
+from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.caches.descriptors import Cache
 import synapse.metrics
 
 from util.id_generators import IdGenerator, StreamIdGenerator
 
 from twisted.internet import defer
 
-from collections import namedtuple, OrderedDict
+from collections import namedtuple
 
-import functools
 import sys
 import time
 import threading
 
-DEBUG_CACHES = False
 
 logger = logging.getLogger(__name__)
 
@@ -47,159 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
 sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
 sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
 
-caches_by_name = {}
-cache_counter = metrics.register_cache(
-    "cache",
-    lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
-    labels=["name"],
-)
-
-
-class Cache(object):
-
-    def __init__(self, name, max_entries=1000, keylen=1, lru=False):
-        if lru:
-            self.cache = LruCache(max_size=max_entries)
-            self.max_entries = None
-        else:
-            self.cache = OrderedDict()
-            self.max_entries = max_entries
-
-        self.name = name
-        self.keylen = keylen
-        self.sequence = 0
-        self.thread = None
-        caches_by_name[name] = self.cache
-
-    def check_thread(self):
-        expected_thread = self.thread
-        if expected_thread is None:
-            self.thread = threading.current_thread()
-        else:
-            if expected_thread is not threading.current_thread():
-                raise ValueError(
-                    "Cache objects can only be accessed from the main thread"
-                )
-
-    def get(self, *keyargs):
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
-
-        if keyargs in self.cache:
-            cache_counter.inc_hits(self.name)
-            return self.cache[keyargs]
-
-        cache_counter.inc_misses(self.name)
-        raise KeyError()
-
-    def update(self, sequence, *args):
-        self.check_thread()
-        if self.sequence == sequence:
-            # Only update the cache if the caches sequence number matches the
-            # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(*args)
-
-    def prefill(self, *args):  # because I can't *keyargs, value
-        keyargs = args[:-1]
-        value = args[-1]
-
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
-
-        if self.max_entries is not None:
-            while len(self.cache) >= self.max_entries:
-                self.cache.popitem(last=False)
-
-        self.cache[keyargs] = value
-
-    def invalidate(self, *keyargs):
-        self.check_thread()
-        if len(keyargs) != self.keylen:
-            raise ValueError("Expected a key to have %d items", self.keylen)
-        # Increment the sequence number so that any SELECT statements that
-        # raced with the INSERT don't update the cache (SYN-369)
-        self.sequence += 1
-        self.cache.pop(keyargs, None)
-
-    def invalidate_all(self):
-        self.check_thread()
-        self.sequence += 1
-        self.cache.clear()
-
-
-class CacheDescriptor(object):
-    """ A method decorator that applies a memoizing cache around the function.
-
-    The function is presumed to take zero or more arguments, which are used in
-    a tuple as the key for the cache. Hits are served directly from the cache;
-    misses use the function body to generate the value.
-
-    The wrapped function has an additional member, a callable called
-    "invalidate". This can be used to remove individual entries from the
-    cache.
-
-    The wrapped function has another additional callable, called "prefill",
-    which can be used to insert values into the cache specifically, without
-    calling the calculation function.
-    """
-    def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
-        self.orig = orig
-
-        self.max_entries = max_entries
-        self.num_args = num_args
-        self.lru = lru
-
-    def __get__(self, obj, objtype=None):
-        cache = Cache(
-            name=self.orig.__name__,
-            max_entries=self.max_entries,
-            keylen=self.num_args,
-            lru=self.lru,
-        )
-
-        @functools.wraps(self.orig)
-        @defer.inlineCallbacks
-        def wrapped(*keyargs):
-            try:
-                cached_result = cache.get(*keyargs[:self.num_args])
-                if DEBUG_CACHES:
-                    actual_result = yield self.orig(obj, *keyargs)
-                    if actual_result != cached_result:
-                        logger.error(
-                            "Stale cache entry %s%r: cached: %r, actual %r",
-                            self.orig.__name__, keyargs,
-                            cached_result, actual_result,
-                        )
-                        raise ValueError("Stale cache entry")
-                defer.returnValue(cached_result)
-            except KeyError:
-                # Get the sequence number of the cache before reading from the
-                # database so that we can tell if the cache is invalidated
-                # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
-
-                ret = yield self.orig(obj, *keyargs)
-
-                cache.update(sequence, *keyargs[:self.num_args] + (ret,))
-
-                defer.returnValue(ret)
-
-        wrapped.invalidate = cache.invalidate
-        wrapped.invalidate_all = cache.invalidate_all
-        wrapped.prefill = cache.prefill
-
-        obj.__dict__[self.orig.__name__] = wrapped
-
-        return wrapped
-
-
-def cached(max_entries=1000, num_args=1, lru=False):
-    return lambda orig: CacheDescriptor(
-        orig,
-        max_entries=max_entries,
-        num_args=num_args,
-        lru=lru
-    )
-
-
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
@@ -321,6 +167,8 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
+        self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000)
+
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
@@ -329,13 +177,14 @@ class SQLBaseStore(object):
 
         self.database_engine = hs.database_engine
 
-        self._stream_id_gen = StreamIdGenerator()
+        self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
         self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
         self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
         self._pushers_id_gen = IdGenerator("pushers", "id", self)
         self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
         self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+        self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
 
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
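
The Cache object removed above (and re-imported from synapse.util.caches.descriptors) is also driven directly, as the unchanged _get_event_cache call site shows. A short sketch of that interface, using invented key values for illustration: prefill inserts without computing, get raises KeyError on a miss, and invalidate bumps the sequence number so racing reads cannot repopulate the entry.

    # Sketch only: exercises the Cache interface recorded in the diff above.
    cache = Cache("example", keylen=2)

    cache.prefill("!room:hs", "@alice:hs", "join")         # insert directly
    assert cache.get("!room:hs", "@alice:hs") == "join"    # counted as a hit

    cache.invalidate("!room:hs", "@alice:hs")              # also bumps cache.sequence
    try:
        cache.get("!room:hs", "@alice:hs")
    except KeyError:
        pass                                               # a miss raises KeyError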
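
The sequence-number dance the SYN-369 comments refer to guards a specific race: a cached SELECT starts, a concurrent write invalidates the key mid-read, and the SELECT's now-stale result must not be written back into the cache. A self-contained toy (not Synapse code) showing why callers sample cache.sequence before reading:

    class SequencedCache(object):
        """Toy version of the sequence check in the diff above."""
        def __init__(self):
            self.cache = {}
            self.sequence = 0

        def invalidate(self, key):
            # A write happened: bump the sequence so in-flight reads are ignored.
            self.sequence += 1
            self.cache.pop(key, None)

        def update(self, sequence, key, value):
            # Only fill the cache if nothing was invalidated since `sequence`
            # was sampled, i.e. the value we read cannot be stale.
            if self.sequence == sequence:
                self.cache[key] = value

    cache = SequencedCache()
    seq = cache.sequence                  # 1. sample the sequence before the SELECT
    row = "old displayname"               # 2. ...slow database read completes...
    cache.invalidate("@alice:hs")         # 3. a concurrent UPDATE lands mid-read
    cache.update(seq, "@alice:hs", row)   # 4. rejected: seq is now stale
    assert "@alice:hs" not in cache.cache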
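
For the decorator itself, a minimal usage sketch. It assumes @cached keeps the interface documented in the removed CacheDescriptor docstring and that, after this commit, it lives alongside Cache in synapse.util.caches.descriptors (only Cache is imported in this diff); the store class, method names, and queries below are illustrative, not part of this change.

    from twisted.internet import defer

    from synapse.util.caches.descriptors import cached

    class ExampleProfileStore(SQLBaseStore):
        @cached(max_entries=5000, num_args=1)
        def get_displayname(self, user_id):
            # On a cache miss this falls through to the database; the
            # result is then stored under the key (user_id,).
            return self._simple_select_one_onecol(
                table="profiles",
                keyvalues={"user_id": user_id},
                retcol="displayname",
            )

        @defer.inlineCallbacks
        def set_displayname(self, user_id, displayname):
            yield self._simple_update_one(
                table="profiles",
                keyvalues={"user_id": user_id},
                updatevalues={"displayname": displayname},
            )
            # Writers invalidate stale entries themselves, via the callable
            # the decorator attaches to the wrapped method.
            self.get_displayname.invalidate(user_id)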