diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 8dba61d49f..a9ea97fd46 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,7 +17,7 @@ import logging
 from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.treecache import TreeCache
+from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
 )
@@ -42,6 +42,25 @@ _CacheSentinel = object()
 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
+class CacheEntry(object):
+    __slots__ = [
+        "deferred", "sequence", "callbacks", "invalidated"
+    ]
+
+    def __init__(self, deferred, sequence, callbacks):
+        self.deferred = deferred
+        self.sequence = sequence
+        self.callbacks = set(callbacks)
+        self.invalidated = False
+
+    def invalidate(self):
+        if not self.invalidated:
+            self.invalidated = True
+            for callback in self.callbacks:
+                callback()
+            self.callbacks.clear()
+
+
 class Cache(object):
     __slots__ = (
         "cache",
@@ -51,12 +70,16 @@ class Cache(object):
"sequence",
"thread",
"metrics",
+ "_pending_deferred_cache",
)
- def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+ def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
cache_type = TreeCache if tree else dict
+ self._pending_deferred_cache = cache_type()
+
self.cache = LruCache(
- max_size=max_entries, keylen=keylen, cache_type=cache_type
+ max_size=max_entries, keylen=keylen, cache_type=cache_type,
+ size_callback=(lambda d: len(d.result)) if iterable else None,
)
self.name = name
@@ -76,7 +99,15 @@ class Cache(object):
         )
 
     def get(self, key, default=_CacheSentinel, callback=None):
-        val = self.cache.get(key, _CacheSentinel, callback=callback)
+        callbacks = [callback] if callback else []
+        val = self._pending_deferred_cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
+            if val.sequence == self.sequence:
+                val.callbacks.update(callbacks)
+                self.metrics.inc_hits()
+                return val.deferred
+
+        val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
         if val is not _CacheSentinel:
             self.metrics.inc_hits()
             return val
@@ -88,15 +119,39 @@ class Cache(object):
         else:
             return default
 
-    def update(self, sequence, key, value, callback=None):
+    def set(self, key, value, callback=None):
+        callbacks = [callback] if callback else []
         self.check_thread()
-        if self.sequence == sequence:
-            # Only update the cache if the caches sequence number matches the
-            # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(key, value, callback=callback)
+        entry = CacheEntry(
+            deferred=value,
+            sequence=self.sequence,
+            callbacks=callbacks,
+        )
+
+        entry.callbacks.update(callbacks)
+
+        existing_entry = self._pending_deferred_cache.pop(key, None)
+        if existing_entry:
+            existing_entry.invalidate()
+
+        self._pending_deferred_cache[key] = entry
+
+        def shuffle(result):
+            if self.sequence == entry.sequence:
+                existing_entry = self._pending_deferred_cache.pop(key, None)
+                if existing_entry is entry:
+                    self.cache.set(key, entry.deferred, entry.callbacks)
+                else:
+                    entry.invalidate()
+            else:
+                entry.invalidate()
+            return result
+
+        entry.deferred.addCallback(shuffle)
 
     def prefill(self, key, value, callback=None):
-        self.cache.set(key, value, callback=callback)
+        callbacks = [callback] if callback else []
+        self.cache.set(key, value, callbacks=callbacks)
 
     def invalidate(self, key):
         self.check_thread()
@@ -108,6 +163,10 @@ class Cache(object):
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
+        entry = self._pending_deferred_cache.pop(key, None)
+        if entry:
+            entry.invalidate()
+
         self.cache.pop(key, None)
 
     def invalidate_many(self, key):
@@ -119,6 +178,12 @@ class Cache(object):
         self.sequence += 1
         self.cache.del_multi(key)
 
+        val = self._pending_deferred_cache.pop(key, None)
+        if val is not None:
+            entry_dict, _ = val
+            for entry in iterate_tree_cache_entry(entry_dict):
+                entry.invalidate()
+
     def invalidate_all(self):
         self.check_thread()
         self.sequence += 1
@@ -155,7 +220,7 @@ class CacheDescriptor(object):
"""
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
- inlineCallbacks=False, cache_context=False):
+ inlineCallbacks=False, cache_context=False, iterable=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@@ -169,6 +234,8 @@ class CacheDescriptor(object):
         self.num_args = num_args
         self.tree = tree
+        self.iterable = iterable
+
         all_args = inspect.getargspec(orig)
         self.arg_names = all_args.args[1:num_args + 1]
@@ -203,6 +270,7 @@ class CacheDescriptor(object):
             max_entries=self.max_entries,
             keylen=self.num_args,
             tree=self.tree,
+            iterable=self.iterable,
         )
 
         @functools.wraps(self.orig)
@@ -243,11 +311,6 @@ class CacheDescriptor(object):
                 return preserve_context_over_deferred(observer)
             except KeyError:
-                # Get the sequence number of the cache before reading from the
-                # database so that we can tell if the cache is invalidated
-                # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
-
                 ret = defer.maybeDeferred(
                     preserve_context_over_fn,
                     self.function_to_call,
@@ -261,7 +324,7 @@ class CacheDescriptor(object):
                 ret.addErrback(onErr)
 
                 ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.update(sequence, cache_key, ret, callback=invalidate_callback)
+                cache.set(cache_key, ret, callback=invalidate_callback)
 
                 return preserve_context_over_deferred(ret.observe())
@@ -359,7 +422,6 @@ class CacheListDescriptor(object):
                     missing.append(arg)
 
             if missing:
-                sequence = cache.sequence
                 args_to_call = dict(arg_dict)
                 args_to_call[self.list_name] = missing
@@ -382,8 +444,8 @@ class CacheListDescriptor(object):
                     key = list(keyargs)
                     key[self.list_pos] = arg
 
-                    cache.update(
-                        sequence, tuple(key), observer,
+                    cache.set(
+                        tuple(key), observer,
                         callback=invalidate_callback
                     )
@@ -421,17 +483,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+           iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         tree=tree,
         cache_context=cache_context,
+        iterable=iterable,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
+                          iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -439,6 +504,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
         tree=tree,
         inlineCallbacks=True,
         cache_context=cache_context,
+        iterable=iterable,
     )
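
The heart of this change is that a result which is still a Deferred is no longer written straight into the LruCache under a sequence-number guard; instead it is parked in `_pending_deferred_cache` as a `CacheEntry` and only promoted into the real cache by `shuffle()` once it resolves, provided no invalidation raced with the lookup. The standalone sketch below models that flow with invented names (`ToyCache`, `ToyEntry`); it illustrates the pattern only and is not synapse's implementation, which additionally tracks invalidation callbacks, metrics and LruCache eviction.

# Illustration only: a stripped-down model of the pending-deferred flow added
# by this diff.  ToyCache/ToyEntry are invented names; the real Cache also
# manages invalidation callbacks, metrics and LruCache eviction.
from twisted.internet import defer


class ToyEntry(object):
    def __init__(self, deferred, sequence):
        self.deferred = deferred
        self.sequence = sequence


class ToyCache(object):
    def __init__(self):
        self.sequence = 0   # bumped on every invalidation (cf. SYN-369)
        self.pending = {}   # key -> ToyEntry for lookups still in flight
        self.cache = {}     # key -> completed result

    def set(self, key, deferred):
        entry = ToyEntry(deferred, self.sequence)
        self.pending[key] = entry

        def shuffle(result):
            # Promote into the real cache only if the key was neither
            # invalidated (sequence bumped) nor replaced while in flight.
            if self.pending.get(key) is entry:
                del self.pending[key]
                if self.sequence == entry.sequence:
                    self.cache[key] = result
            return result

        deferred.addCallback(shuffle)

    def invalidate(self, key):
        self.sequence += 1
        self.pending.pop(key, None)
        self.cache.pop(key, None)


# A lookup that is invalidated while still in flight never lands in the cache.
toy = ToyCache()
d = defer.Deferred()
toy.set("room1", d)
toy.invalidate("room1")               # races with the pending lookup
d.callback(["@alice:example.com"])    # resolves after the invalidation
assert "room1" not in toy.cache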
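
The new `iterable` flag changes how entries are weighed: with `iterable=True` the LruCache is built with `size_callback=lambda d: len(d.result)`, so a cached value is charged `len()` of its resolved result against `max_entries` rather than counting as a single entry, and decorated methods are therefore expected to return something with a length. A hypothetical opt-in might look like the following; `ExampleStore` and `_fetch_members` are made-up names, not part of this diff.

# Hypothetical example, not part of this diff: opting a store method into
# per-item size accounting.  ExampleStore and _fetch_members are invented.
from twisted.internet import defer

from synapse.util.caches.descriptors import cachedInlineCallbacks


class ExampleStore(object):
    @cachedInlineCallbacks(num_args=1, iterable=True)
    def get_users_in_room(self, room_id):
        # The cached value is the list returned here; with iterable=True the
        # LruCache charges this entry len(result) against max_entries once
        # the deferred resolves, instead of counting it as a single entry.
        rows = yield self._fetch_members(room_id)   # imaginary DB helper
        defer.returnValue([row["user_id"] for row in rows])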