diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 8d33def6c6..d976e17786 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,21 +17,20 @@ import logging
from synapse.api.errors import StoreError
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
-from synapse.util.lrucache import LruCache
+from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.caches.descriptors import Cache
import synapse.metrics
from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
-from collections import namedtuple, OrderedDict
+from collections import namedtuple
-import functools
import sys
import time
import threading
-DEBUG_CACHES = False
logger = logging.getLogger(__name__)
@@ -47,159 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
-caches_by_name = {}
-cache_counter = metrics.register_cache(
- "cache",
- lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
- labels=["name"],
-)
-
-
-class Cache(object):
-
- def __init__(self, name, max_entries=1000, keylen=1, lru=False):
- if lru:
- self.cache = LruCache(max_size=max_entries)
- self.max_entries = None
- else:
- self.cache = OrderedDict()
- self.max_entries = max_entries
-
- self.name = name
- self.keylen = keylen
- self.sequence = 0
- self.thread = None
- caches_by_name[name] = self.cache
-
- def check_thread(self):
- expected_thread = self.thread
- if expected_thread is None:
- self.thread = threading.current_thread()
- else:
- if expected_thread is not threading.current_thread():
- raise ValueError(
- "Cache objects can only be accessed from the main thread"
- )
-
- def get(self, *keyargs):
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
-
- if keyargs in self.cache:
- cache_counter.inc_hits(self.name)
- return self.cache[keyargs]
-
- cache_counter.inc_misses(self.name)
- raise KeyError()
-
- def update(self, sequence, *args):
- self.check_thread()
- if self.sequence == sequence:
- # Only update the cache if the caches sequence number matches the
- # number that the cache had before the SELECT was started (SYN-369)
- self.prefill(*args)
-
- def prefill(self, *args): # because I can't *keyargs, value
- keyargs = args[:-1]
- value = args[-1]
-
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
-
- if self.max_entries is not None:
- while len(self.cache) >= self.max_entries:
- self.cache.popitem(last=False)
-
- self.cache[keyargs] = value
-
- def invalidate(self, *keyargs):
- self.check_thread()
- if len(keyargs) != self.keylen:
- raise ValueError("Expected a key to have %d items", self.keylen)
- # Increment the sequence number so that any SELECT statements that
- # raced with the INSERT don't update the cache (SYN-369)
- self.sequence += 1
- self.cache.pop(keyargs, None)
-
- def invalidate_all(self):
- self.check_thread()
- self.sequence += 1
- self.cache.clear()
-
-
-class CacheDescriptor(object):
- """ A method decorator that applies a memoizing cache around the function.
-
- The function is presumed to take zero or more arguments, which are used in
- a tuple as the key for the cache. Hits are served directly from the cache;
- misses use the function body to generate the value.
-
- The wrapped function has an additional member, a callable called
- "invalidate". This can be used to remove individual entries from the cache.
-
- The wrapped function has another additional callable, called "prefill",
- which can be used to insert values into the cache specifically, without
- calling the calculation function.
- """
- def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
- self.orig = orig
-
- self.max_entries = max_entries
- self.num_args = num_args
- self.lru = lru
-
- def __get__(self, obj, objtype=None):
- cache = Cache(
- name=self.orig.__name__,
- max_entries=self.max_entries,
- keylen=self.num_args,
- lru=self.lru,
- )
-
- @functools.wraps(self.orig)
- @defer.inlineCallbacks
- def wrapped(*keyargs):
- try:
- cached_result = cache.get(*keyargs[:self.num_args])
- if DEBUG_CACHES:
- actual_result = yield self.orig(obj, *keyargs)
- if actual_result != cached_result:
- logger.error(
- "Stale cache entry %s%r: cached: %r, actual %r",
- self.orig.__name__, keyargs,
- cached_result, actual_result,
- )
- raise ValueError("Stale cache entry")
- defer.returnValue(cached_result)
- except KeyError:
- # Get the sequence number of the cache before reading from the
- # database so that we can tell if the cache is invalidated
- # while the SELECT is executing (SYN-369)
- sequence = cache.sequence
-
- ret = yield self.orig(obj, *keyargs)
-
- cache.update(sequence, *keyargs[:self.num_args] + (ret,))
-
- defer.returnValue(ret)
-
- wrapped.invalidate = cache.invalidate
- wrapped.invalidate_all = cache.invalidate_all
- wrapped.prefill = cache.prefill
-
- obj.__dict__[self.orig.__name__] = wrapped
-
- return wrapped
-
-
-def cached(max_entries=1000, num_args=1, lru=False):
- return lambda orig: CacheDescriptor(
- orig,
- max_entries=max_entries,
- num_args=num_args,
- lru=lru
- )
-
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
@@ -321,6 +167,8 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size)
+ self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000)
+
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
@@ -329,13 +177,14 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine
- self._stream_id_gen = StreamIdGenerator()
+ self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+ self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
|