diff options
Diffstat (limited to 'synapse/util/caches')
-rw-r--r-- | synapse/util/caches/__init__.py | 20 | ||||
-rw-r--r-- | synapse/util/caches/descriptors.py | 93 | ||||
-rw-r--r-- | synapse/util/caches/dictionary_cache.py | 8 | ||||
-rw-r--r-- | synapse/util/caches/expiringcache.py | 8 | ||||
-rw-r--r-- | synapse/util/caches/response_cache.py | 46 | ||||
-rw-r--r-- | synapse/util/caches/stream_change_cache.py | 16 |
6 files changed, 137 insertions, 54 deletions
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index d53569ca49..ebd715c5dc 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -24,11 +24,21 @@ DEBUG_CACHES = False metrics = synapse.metrics.get_metrics_for("synapse.util.caches") caches_by_name = {} -cache_counter = metrics.register_cache( - "cache", - lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, - labels=["name"], -) +# cache_counter = metrics.register_cache( +# "cache", +# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, +# labels=["name"], +# ) + + +def register_cache(name, cache): + caches_by_name[name] = cache + return metrics.register_cache( + "cache", + lambda: len(cache), + name, + ) + _string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR)) caches_by_name["string_cache"] = _string_cache diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 35544b19fd..f31dfb22b7 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -22,7 +22,7 @@ from synapse.util.logcontext import ( PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn ) -from . import caches_by_name, DEBUG_CACHES, cache_counter +from . import DEBUG_CACHES, register_cache from twisted.internet import defer @@ -33,6 +33,7 @@ import functools import inspect import threading + logger = logging.getLogger(__name__) @@ -43,6 +44,15 @@ CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) class Cache(object): + __slots__ = ( + "cache", + "max_entries", + "name", + "keylen", + "sequence", + "thread", + "metrics", + ) def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False): if lru: @@ -59,7 +69,7 @@ class Cache(object): self.keylen = keylen self.sequence = 0 self.thread = None - caches_by_name[name] = self.cache + self.metrics = register_cache(name, self.cache) def check_thread(self): expected_thread = self.thread @@ -74,10 +84,10 @@ class Cache(object): def get(self, key, default=_CacheSentinel): val = self.cache.get(key, _CacheSentinel) if val is not _CacheSentinel: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return val - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() if default is _CacheSentinel: raise KeyError() @@ -167,7 +177,8 @@ class CacheDescriptor(object): % (orig.__name__,) ) - self.cache = Cache( + def __get__(self, obj, objtype=None): + cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, @@ -175,14 +186,12 @@ class CacheDescriptor(object): tree=self.tree, ) - def __get__(self, obj, objtype=None): - @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) try: - cached_result_d = self.cache.get(cache_key) + cached_result_d = cache.get(cache_key) observer = cached_result_d.observe() if DEBUG_CACHES: @@ -204,7 +213,7 @@ class CacheDescriptor(object): # Get the sequence number of the cache before reading from the # database so that we can tell if the cache is invalidated # while the SELECT is executing (SYN-369) - sequence = self.cache.sequence + sequence = cache.sequence ret = defer.maybeDeferred( preserve_context_over_fn, @@ -213,20 +222,21 @@ class CacheDescriptor(object): ) def onErr(f): - self.cache.invalidate(cache_key) + cache.invalidate(cache_key) return f ret.addErrback(onErr) ret = ObservableDeferred(ret, consumeErrors=True) - self.cache.update(sequence, cache_key, ret) + cache.update(sequence, cache_key, ret) return preserve_context_over_deferred(ret.observe()) - wrapped.invalidate = self.cache.invalidate - wrapped.invalidate_all = self.cache.invalidate_all - wrapped.invalidate_many = self.cache.invalidate_many - wrapped.prefill = self.cache.prefill + wrapped.invalidate = cache.invalidate + wrapped.invalidate_all = cache.invalidate_all + wrapped.invalidate_many = cache.invalidate_many + wrapped.prefill = cache.prefill + wrapped.cache = cache obj.__dict__[self.orig.__name__] = wrapped @@ -240,11 +250,12 @@ class CacheListDescriptor(object): the list of missing keys to the wrapped fucntion. """ - def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): + def __init__(self, orig, cached_method_name, list_name, num_args=1, + inlineCallbacks=False): """ Args: orig (function) - cache (Cache) + method_name (str); The name of the chached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int) inlineCallbacks (bool): Whether orig is a generator that should @@ -263,7 +274,7 @@ class CacheListDescriptor(object): self.arg_names = inspect.getargspec(orig).args[1:num_args + 1] self.list_pos = self.arg_names.index(self.list_name) - self.cache = cache + self.cached_method_name = cached_method_name self.sentinel = object() @@ -277,11 +288,13 @@ class CacheListDescriptor(object): if self.list_name not in self.arg_names: raise Exception( "Couldn't see arguments %r for %r." - % (self.list_name, cache.name,) + % (self.list_name, cached_method_name,) ) def __get__(self, obj, objtype=None): + cache = getattr(obj, self.cached_method_name).cache + @functools.wraps(self.orig) def wrapped(*args, **kwargs): arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) @@ -290,21 +303,26 @@ class CacheListDescriptor(object): # cached is a dict arg -> deferred, where deferred results in a # 2-tuple (`arg`, `result`) - cached = {} + results = {} + cached_defers = {} missing = [] for arg in list_args: key = list(keyargs) key[self.list_pos] = arg try: - res = self.cache.get(tuple(key)).observe() - res.addCallback(lambda r, arg: (arg, r), arg) - cached[arg] = res + res = cache.get(tuple(key)) + if not res.has_succeeded(): + res = res.observe() + res.addCallback(lambda r, arg: (arg, r), arg) + cached_defers[arg] = res + else: + results[arg] = res.get_result() except KeyError: missing.append(arg) if missing: - sequence = self.cache.sequence + sequence = cache.sequence args_to_call = dict(arg_dict) args_to_call[self.list_name] = missing @@ -327,22 +345,31 @@ class CacheListDescriptor(object): key = list(keyargs) key[self.list_pos] = arg - self.cache.update(sequence, tuple(key), observer) + cache.update(sequence, tuple(key), observer) def invalidate(f, key): - self.cache.invalidate(key) + cache.invalidate(key) return f observer.addErrback(invalidate, tuple(key)) res = observer.observe() res.addCallback(lambda r, arg: (arg, r), arg) - cached[arg] = res + cached_defers[arg] = res + + if cached_defers: + def update_results_dict(res): + results.update(res) + return results - return preserve_context_over_deferred(defer.gatherResults( - cached.values(), - consumeErrors=True, - ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))) + return preserve_context_over_deferred(defer.gatherResults( + cached_defers.values(), + consumeErrors=True, + ).addCallback(update_results_dict).addErrback( + unwrapFirstError + )) + else: + return results obj.__dict__[self.orig.__name__] = wrapped @@ -370,7 +397,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False): ) -def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): +def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False): """Creates a descriptor that wraps a function in a `CacheListDescriptor`. Used to do batch lookups for an already created cache. A single argument @@ -400,7 +427,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): """ return lambda orig: CacheListDescriptor( orig, - cache=cache, + cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, inlineCallbacks=inlineCallbacks, diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index f92d80542b..b0ca1bb79d 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -15,7 +15,7 @@ from synapse.util.caches.lrucache import LruCache from collections import namedtuple -from . import caches_by_name, cache_counter +from . import register_cache import threading import logging @@ -43,7 +43,7 @@ class DictionaryCache(object): __slots__ = [] self.sentinel = Sentinel() - caches_by_name[name] = self.cache + self.metrics = register_cache(name, self.cache) def check_thread(self): expected_thread = self.thread @@ -58,7 +58,7 @@ class DictionaryCache(object): def get(self, key, dict_keys=None): entry = self.cache.get(key, self.sentinel) if entry is not self.sentinel: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() if dict_keys is None: return DictionaryEntry(entry.full, dict(entry.value)) @@ -69,7 +69,7 @@ class DictionaryCache(object): if k in entry.value }) - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return DictionaryEntry(False, {}) def invalidate(self, key): diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 2b68c1ac93..080388958f 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import cache_counter, caches_by_name +from synapse.util.caches import register_cache import logging @@ -49,7 +49,7 @@ class ExpiringCache(object): self._cache = {} - caches_by_name[cache_name] = self._cache + self.metrics = register_cache(cache_name, self._cache) def start(self): if not self._expiry_ms: @@ -78,9 +78,9 @@ class ExpiringCache(object): def __getitem__(self, key): try: entry = self._cache[key] - cache_counter.inc_hits(self._cache_name) + self.metrics.inc_hits() except KeyError: - cache_counter.inc_misses(self._cache_name) + self.metrics.inc_misses() raise if self._reset_expiry_on_get: diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py new file mode 100644 index 0000000000..36686b479e --- /dev/null +++ b/synapse/util/caches/response_cache.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.async import ObservableDeferred + + +class ResponseCache(object): + """ + This caches a deferred response. Until the deferred completes it will be + returned from the cache. This means that if the client retries the request + while the response is still being computed, that original response will be + used rather than trying to compute a new response. + """ + + def __init__(self): + self.pending_result_cache = {} # Requests that haven't finished yet. + + def get(self, key): + result = self.pending_result_cache.get(key) + if result is not None: + return result.observe() + else: + return None + + def set(self, key, deferred): + result = ObservableDeferred(deferred, consumeErrors=True) + self.pending_result_cache[key] = result + + def remove(r): + self.pending_result_cache.pop(key, None) + return r + + result.addBoth(remove) + return result.observe() diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index ea8a74ca69..3c051dabc4 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.caches import cache_counter, caches_by_name +from synapse.util.caches import register_cache from blist import sorteddict @@ -42,7 +42,7 @@ class StreamChangeCache(object): self._cache = sorteddict() self._earliest_known_stream_pos = current_stream_pos self.name = name - caches_by_name[self.name] = self._cache + self.metrics = register_cache(self.name, self._cache) for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) @@ -53,19 +53,19 @@ class StreamChangeCache(object): assert type(stream_pos) is int if stream_pos < self._earliest_known_stream_pos: - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return True latest_entity_change_pos = self._entity_to_key.get(entity, None) if latest_entity_change_pos is None: - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return False if stream_pos < latest_entity_change_pos: - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return True - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() return False def get_entities_changed(self, entities, stream_pos): @@ -82,10 +82,10 @@ class StreamChangeCache(object): self._cache[k] for k in keys[i:] ).intersection(entities) - cache_counter.inc_hits(self.name) + self.metrics.inc_hits() else: result = entities - cache_counter.inc_misses(self.name) + self.metrics.inc_misses() return result |