diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index d082c26b1f..b3b2d6092d 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,7 +17,7 @@ import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.treecache import TreeCache
+from synapse.util.caches.treecache import TreeCache, popped_to_iterator
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
@@ -42,11 +42,23 @@ _CacheSentinel = object()
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-def deferred_size(deferred):
- if deferred.called:
- return len(deferred.result)
- else:
- return 1
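+# Tracks a single pending cache lookup: the deferred that will supply the
+# value, the cache sequence number at the time the lookup started, and the
+# invalidation callbacks to fire if the entry is superseded or invalidated.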
+class CacheEntry(object):
+ __slots__ = [
+ "deferred", "sequence", "callbacks", "invalidated"
+ ]
+
+ def __init__(self, deferred, sequence, callbacks):
+ self.deferred = deferred
+ self.sequence = sequence
+ self.callbacks = set(callbacks)
+ self.invalidated = False
+
+ def invalidate(self):
+ if not self.invalidated:
+ self.invalidated = True
+ for callback in self.callbacks:
+ callback()
+ self.callbacks.clear()
class Cache(object):
@@ -58,13 +70,16 @@ class Cache(object):
"sequence",
"thread",
"metrics",
+ "_pending_deferred_cache",
)
def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
cache_type = TreeCache if tree else dict
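+        # Lookups that are still in flight live here as CacheEntry objects;
+        # completed results are promoted into self.cache (the LruCache) once
+        # their deferreds resolve.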
+ self._pending_deferred_cache = cache_type()
+
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type,
- size_callback=deferred_size if iterable else None,
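+            # for iterable caches, weight each cached deferred by the length
+            # of its result instead of counting it as a single item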
+ size_callback=(lambda d: len(d.result)) if iterable else None,
)
self.name = name
@@ -84,7 +99,15 @@ class Cache(object):
)
def get(self, key, default=_CacheSentinel, callback=None):
- val = self.cache.get(key, _CacheSentinel, callback=callback)
+ callbacks = [callback] if callback else []
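+        # Check in-flight lookups first; a pending entry only counts as a hit
+        # if the cache has not been invalidated since the lookup started.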
+ val = self._pending_deferred_cache.get(key, _CacheSentinel)
+ if val is not _CacheSentinel:
+ if val.sequence == self.sequence:
+ val.callbacks.update(callbacks)
+ self.metrics.inc_hits()
+ return val.deferred
+
+ val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
@@ -96,15 +119,39 @@ class Cache(object):
else:
return default
- def update(self, sequence, key, value, callback=None):
+ def set(self, key, value, callback=None):
+ callbacks = [callback] if callback else []
self.check_thread()
- if self.sequence == sequence:
- # Only update the cache if the caches sequence number matches the
- # number that the cache had before the SELECT was started (SYN-369)
- self.prefill(key, value, callback=callback)
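+        # Park the lookup in the pending cache; it is only promoted to the
+        # LruCache once the deferred completes and the sequence number still
+        # matches (see shuffle below).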
+ entry = CacheEntry(
+ deferred=value,
+ sequence=self.sequence,
+ callbacks=callbacks,
+ )
+
+ existing_entry = self._pending_deferred_cache.pop(key, None)
+ if existing_entry:
+ existing_entry.invalidate()
+
+ self._pending_deferred_cache[key] = entry
+
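+        # Once the deferred fires, move the result into the real cache, but
+        # only if nothing has invalidated or replaced this entry in the
+        # meantime; otherwise fire the entry's invalidation callbacks. The
+        # result is returned unchanged so later callbacks still see it.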
+ def shuffle(result):
+ if self.sequence == entry.sequence:
+ existing_entry = self._pending_deferred_cache.pop(key, None)
+ if existing_entry is entry:
+ self.cache.set(key, entry.deferred, entry.callbacks)
+ else:
+ entry.invalidate()
+ else:
+ entry.invalidate()
+ return result
+
+ entry.deferred.addCallback(shuffle)
def prefill(self, key, value, callback=None):
- self.cache.set(key, value, callback=callback)
+ callbacks = [callback] if callback else []
+ self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
self.check_thread()
@@ -116,6 +163,10 @@ class Cache(object):
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
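+        # Also drop any pending lookup for this key and fire its
+        # invalidation callbacks.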
+ entry = self._pending_deferred_cache.pop(key, None)
+ if entry:
+ entry.invalidate()
+
self.cache.pop(key, None)
def invalidate_many(self, key):
@@ -127,6 +178,12 @@ class Cache(object):
self.sequence += 1
self.cache.del_multi(key)
+        # Remove any pending lookups under this key and fire their
+        # invalidation callbacks; the popped value may be a single entry or
+        # a nested dict of entries, and popped_to_iterator handles both.
+        popped = self._pending_deferred_cache.pop(key, None)
+        if popped is not None:
+            for entry in popped_to_iterator(popped):
+                entry.invalidate()
+
def invalidate_all(self):
self.check_thread()
self.sequence += 1
@@ -254,11 +311,6 @@ class CacheDescriptor(object):
return preserve_context_over_deferred(observer)
except KeyError:
- # Get the sequence number of the cache before reading from the
- # database so that we can tell if the cache is invalidated
- # while the SELECT is executing (SYN-369)
- sequence = cache.sequence
-
ret = defer.maybeDeferred(
preserve_context_over_fn,
self.function_to_call,
@@ -272,7 +324,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
- cache.update(sequence, cache_key, ret, callback=invalidate_callback)
+ cache.set(cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
@@ -370,7 +422,6 @@ class CacheListDescriptor(object):
missing.append(arg)
if missing:
- sequence = cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
@@ -393,8 +444,8 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
- cache.update(
- sequence, tuple(key), observer,
+ cache.set(
+ tuple(key), observer,
callback=invalidate_callback
)
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index b0ca1bb79d..cb6933c61c 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -23,7 +23,9 @@ import logging
logger = logging.getLogger(__name__)
-DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
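+# Give DictionaryEntry a __len__ so the LruCache size_callback can weight
+# each entry by the number of values it holds.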
+class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
+ def __len__(self):
+ return len(self.value)
class DictionaryCache(object):
@@ -32,7 +34,7 @@ class DictionaryCache(object):
"""
def __init__(self, name, max_entries=1000):
- self.cache = LruCache(max_size=max_entries)
+ self.cache = LruCache(max_size=max_entries, size_callback=len)
self.name = name
self.sequence = 0
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index b9ead9cbd5..2987c38a2d 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -56,6 +56,8 @@ class ExpiringCache(object):
self.iterable = iterable
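+        # Running total of len(value) across all entries, updated on insert,
+        # eviction and pruning, so __len__ stays O(1) for iterable caches.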
+ self._size_estimate = 0
+
def start(self):
if not self._expiry_ms:
# Don't bother starting the loop if things never expire
@@ -70,9 +72,14 @@ class ExpiringCache(object):
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
+ if self.iterable:
+ self._size_estimate += len(value)
+
# Evict if there are now too many items
while self._max_len and len(self) > self._max_len:
- self._cache.popitem(last=False)
+ _key, value = self._cache.popitem(last=False)
+ if self.iterable:
+ self._size_estimate -= len(value.value)
def __getitem__(self, key):
try:
@@ -109,7 +116,9 @@ class ExpiringCache(object):
keys_to_delete.add(key)
for k in keys_to_delete:
- self._cache.pop(k)
+ value = self._cache.pop(k)
+ if self.iterable:
+ self._size_estimate -= len(value.value)
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
@@ -118,7 +127,7 @@ class ExpiringCache(object):
def __len__(self):
if self.iterable:
- return sum(len(value.value) for value in self._cache.itervalues())
+ return self._size_estimate
else:
return len(self._cache)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 00ddf38290..f1de034444 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -58,12 +58,6 @@ class LruCache(object):
lock = threading.Lock()
- def cache_len():
- if size_callback is not None:
- return sum(size_callback(node.value) for node in cache.itervalues())
- else:
- return len(cache)
-
def evict():
while cache_len() > max_size:
todelete = list_root.prev_node
@@ -78,6 +72,16 @@ class LruCache(object):
return inner
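+    # Track the total cache size incrementally in a one-element list (a cell
+    # the nested closures can mutate) instead of recomputing it by summing
+    # size_callback over every node on each call to len().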
+ cached_cache_len = [0]
+ if size_callback is not None:
+ def cache_len():
+ return cached_cache_len[0]
+ else:
+ def cache_len():
+ return len(cache)
+
+ self.len = synchronized(cache_len)
+
def add_node(key, value, callbacks=set()):
prev_node = list_root
next_node = prev_node.next_node
@@ -86,6 +90,9 @@ class LruCache(object):
next_node.prev_node = node
cache[key] = node
+ if size_callback:
+ cached_cache_len[0] += size_callback(node.value)
+
def move_node_to_front(node):
prev_node = node.prev_node
next_node = node.next_node
@@ -104,23 +111,25 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
+ if size_callback:
+ cached_cache_len[0] -= size_callback(node.value)
+
for cb in node.callbacks:
cb()
node.callbacks.clear()
@synchronized
- def cache_get(key, default=None, callback=None):
+ def cache_get(key, default=None, callbacks=[]):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
- if callback:
- node.callbacks.add(callback)
+ node.callbacks.update(callbacks)
return node.value
else:
return default
@synchronized
- def cache_set(key, value, callback=None):
+ def cache_set(key, value, callbacks=[]):
node = cache.get(key, None)
if node is not None:
if value != node.value:
@@ -128,17 +137,16 @@ class LruCache(object):
cb()
node.callbacks.clear()
- if callback:
- node.callbacks.add(callback)
+ if size_callback:
+ cached_cache_len[0] -= size_callback(node.value)
+ cached_cache_len[0] += size_callback(value)
+
+ node.callbacks.update(callbacks)
move_node_to_front(node)
node.value = value
else:
- if callback:
- callbacks = set([callback])
- else:
- callbacks = set()
- add_node(key, value, callbacks)
+ add_node(key, value, set(callbacks))
evict()
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index c31585aea3..460e98a92d 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -65,12 +65,24 @@ class TreeCache(object):
return popped
def values(self):
- return [e.value for e in self.root.values()]
+ return list(popped_to_iterator(self.root))
def __len__(self):
return self.size
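+# Recursively walk a node popped from (or stored in) a TreeCache and yield
+# the leaf values, unwrapping _Entry objects along the way.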
+def popped_to_iterator(d):
+ if isinstance(d, dict):
+ for value_d in d.itervalues():
+ for value in popped_to_iterator(value_d):
+ yield value
+ else:
+ if isinstance(d, _Entry):
+ yield d.value
+ else:
+ yield d
+
+
class _Entry(object):
__slots__ = ["value"]
|