diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index da0e06a468..1a14904194 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 362944bc51..277854ccbc 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +17,10 @@ import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
from . import caches_by_name, DEBUG_CACHES, cache_counter
@@ -36,9 +40,12 @@ _CacheSentinel = object()
class Cache(object):
- def __init__(self, name, max_entries=1000, keylen=1, lru=True):
+ def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
if lru:
- self.cache = LruCache(max_size=max_entries)
+ cache_type = TreeCache if tree else dict
+ self.cache = LruCache(
+ max_size=max_entries, keylen=keylen, cache_type=cache_type
+ )
self.max_entries = None
else:
self.cache = OrderedDict()
@@ -99,6 +106,15 @@ class Cache(object):
self.sequence += 1
self.cache.pop(key, None)
+ def invalidate_many(self, key):
+ self.check_thread()
+ if not isinstance(key, tuple):
+ raise TypeError(
+ "The cache key must be a tuple not %r" % (type(key),)
+ )
+ self.sequence += 1
+ self.cache.del_multi(key)
+
def invalidate_all(self):
self.check_thread()
self.sequence += 1
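
Usage sketch (not part of the patch): what the new tree/invalidate_many combination buys at the Cache level. Keys are tuples of length keylen; invalidate_many takes a key prefix and drops every entry under it. Cache.prefill(key, value) is only referenced further down in this diff, and the cache name and keys below are invented.

    from synapse.util.caches.descriptors import Cache

    cache = Cache("receipts", max_entries=1000, keylen=2, lru=True, tree=True)
    cache.prefill(("!room:hs", "@alice:hs"), "m.read")
    cache.prefill(("!room:hs", "@bob:hs"), "m.read")
    cache.prefill(("!other:hs", "@alice:hs"), "m.read")

    # One prefix invalidation drops every entry for the room; the argument
    # must be a tuple (a bare string would raise the TypeError added above).
    cache.invalidate_many(("!room:hs",))
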
@@ -122,7 +138,7 @@ class CacheDescriptor(object):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
- def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
+ def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
inlineCallbacks=False):
self.orig = orig
@@ -134,8 +150,9 @@ class CacheDescriptor(object):
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
+ self.tree = tree
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+ self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
if len(self.arg_names) < self.num_args:
raise Exception(
@@ -149,6 +166,7 @@ class CacheDescriptor(object):
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
+ tree=self.tree,
)
def __get__(self, obj, objtype=None):
@@ -175,7 +193,7 @@ class CacheDescriptor(object):
defer.returnValue(cached_result)
observer.addCallback(check_result)
- return observer
+ return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
@@ -183,6 +201,7 @@ class CacheDescriptor(object):
sequence = self.cache.sequence
ret = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
obj, *args, **kwargs
)
@@ -196,10 +215,11 @@ class CacheDescriptor(object):
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
- return ret.observe()
+ return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
+ wrapped.invalidate_many = self.cache.invalidate_many
wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
@@ -234,7 +254,7 @@ class CacheListDescriptor(object):
self.num_args = num_args
self.list_name = list_name
- self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+ self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
@@ -283,6 +303,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
**args_to_call
)
@@ -292,7 +313,8 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
- observer = ret_d.observe()
+ with PreserveLoggingContext():
+ observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
@@ -311,31 +333,33 @@ class CacheListDescriptor(object):
cached[arg] = res
- return defer.gatherResults(
+ return preserve_context_over_deferred(defer.gatherResults(
cached.values(),
consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+ ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
-def cached(max_entries=1000, num_args=1, lru=True):
+def cached(max_entries=1000, num_args=1, lru=True, tree=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
- lru=lru
+ lru=lru,
+ tree=tree,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
+ tree=tree,
inlineCallbacks=True,
)
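
Usage sketch (not part of the patch): how a store method opts into the tree-backed cache via the decorators changed above. The class, method and database helper below are invented for illustration; only the decorator arguments and the invalidate_many attribute come from this diff.

    from twisted.internet import defer
    from synapse.util.caches.descriptors import cachedInlineCallbacks

    class ReceiptsStore(object):
        @cachedInlineCallbacks(num_args=2, tree=True)
        def get_receipt(self, room_id, user_id):
            # _fetch_receipt_from_db is a hypothetical database lookup.
            result = yield self._fetch_receipt_from_db(room_id, user_id)
            defer.returnValue(result)

        def invalidate_room(self, room_id):
            # The wrapped method now exposes invalidate_many; passing a key
            # prefix clears every cached (room_id, user_id) entry for the room.
            self.get_receipt.invalidate_many((room_id,))
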
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index e69adf62fe..f92d80542b 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 06d1eea01b..62cae99649 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -55,7 +55,7 @@ class ExpiringCache(object):
def f():
self._prune_cache()
- self._clock.looping_call(f, self._expiry_ms/2)
+ self._clock.looping_call(f, self._expiry_ms / 2)
def __setitem__(self, key, value):
now = self._clock.time_msec()
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cacd7e45fa..f7423f2fab 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,11 +17,27 @@
from functools import wraps
import threading
+from synapse.util.caches.treecache import TreeCache
+
+
+def enumerate_leaves(node, depth):
+ if depth == 0:
+ yield node
+ else:
+ for n in node.values():
+ for m in enumerate_leaves(n, depth - 1):
+ yield m
+
class LruCache(object):
- """Least-recently-used cache."""
- def __init__(self, max_size):
- cache = {}
+ """
+ Least-recently-used cache.
+ Supports del_multi only if cache_type=TreeCache.
+ If cache_type=TreeCache, all keys must be tuples.
+ """
+ def __init__(self, max_size, keylen=1, cache_type=dict):
+ cache = cache_type()
+ self.cache = cache # Used for introspection.
list_root = []
list_root[:] = [list_root, list_root, None, None]
@@ -62,7 +78,6 @@ class LruCache(object):
next_node = node[NEXT]
prev_node[NEXT] = next_node
next_node[PREV] = prev_node
- cache.pop(node[KEY], None)
@synchronized
def cache_get(key, default=None):
@@ -82,7 +97,9 @@ class LruCache(object):
else:
add_node(key, value)
if len(cache) > max_size:
- delete_node(list_root[PREV])
+ todelete = list_root[PREV]
+ delete_node(todelete)
+ cache.pop(todelete[KEY], None)
@synchronized
def cache_set_default(key, value):
@@ -92,7 +109,9 @@ class LruCache(object):
else:
add_node(key, value)
if len(cache) > max_size:
- delete_node(list_root[PREV])
+ todelete = list_root[PREV]
+ delete_node(todelete)
+ cache.pop(todelete[KEY], None)
return value
@synchronized
@@ -100,11 +119,23 @@ class LruCache(object):
node = cache.get(key, None)
if node:
delete_node(node)
+ cache.pop(node[KEY], None)
return node[VALUE]
else:
return default
@synchronized
+ def cache_del_multi(key):
+ """
+ This will only work if constructed with cache_type=TreeCache.
+ """
+ popped = cache.pop(key)
+ if popped is None:
+ return
+ for leaf in enumerate_leaves(popped, keylen - len(key)):
+ delete_node(leaf)
+
+ @synchronized
def cache_clear():
list_root[NEXT] = list_root
list_root[PREV] = list_root
@@ -123,6 +154,8 @@ class LruCache(object):
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
+ if cache_type is TreeCache:
+ self.del_multi = cache_del_multi
self.len = cache_len
self.contains = cache_contains
self.clear = cache_clear
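
Usage sketch (not part of the patch) of the new keylen/cache_type parameters, using the accessors LruCache exposes (set, get, len and the new del_multi):

    from synapse.util.caches.lrucache import LruCache
    from synapse.util.caches.treecache import TreeCache

    cache = LruCache(max_size=100, keylen=2, cache_type=TreeCache)
    cache.set(("!room:hs", "@alice:hs"), 1)
    cache.set(("!room:hs", "@bob:hs"), 2)
    cache.set(("!other:hs", "@alice:hs"), 3)

    # del_multi unlinks every LRU node under the key prefix and removes the
    # whole subtree from the TreeCache in one call.
    cache.del_multi(("!room:hs",))
    assert cache.len() == 1
    assert cache.get(("!other:hs", "@alice:hs")) == 3
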
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index 09f00afbc5..d03678b8c8 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
+# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -87,7 +87,8 @@ class SnapshotCache(object):
# expire from the rotation of that cache.
self.next_result_cache[key] = result
self.pending_result_cache.pop(key, None)
+ return r
- result.observe().addBoth(shuffle_along)
+ result.addBoth(shuffle_along)
return result.observe()
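
A note on the snapshot_cache change above: shuffle_along is now attached directly to the underlying deferred chain rather than to a throwaway observer, so it must return its argument, otherwise the deferred's result would be replaced with None for anything chained after it. A standalone Twisted sketch of that pitfall (not Synapse code):

    from twisted.internet import defer

    results = []

    d = defer.succeed("value")
    d.addBoth(lambda r: None)        # swallows the result...
    d.addCallback(results.append)    # ...so the chain now carries None

    d2 = defer.succeed("value")
    d2.addBoth(lambda r: r)          # the equivalent of the added `return r`
    d2.addCallback(results.append)   # the chain still carries "value"

    assert results == [None, "value"]
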
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
new file mode 100644
index 0000000000..b37f1c0725
--- /dev/null
+++ b/synapse/util/caches/stream_change_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.caches import cache_counter, caches_by_name
+
+
+from blist import sorteddict
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class StreamChangeCache(object):
+ """Keeps track of the stream positions of the latest change in a set of entities.
+
+ Typically the entity will be a room or user id.
+
+ Given a list of entities and a stream position, it will give a subset of
+ entities that may have changed since that position. If the position is too
+ old then the cache will simply return all of the given entities.
+ """
+ def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
+ self._max_size = max_size
+ self._entity_to_key = {}
+ self._cache = sorteddict()
+ self._earliest_known_stream_pos = current_stream_pos
+ self.name = name
+ caches_by_name[self.name] = self._cache
+
+ for entity, stream_pos in prefilled_cache.items():
+ self.entity_has_changed(entity, stream_pos)
+
+ def has_entity_changed(self, entity, stream_pos):
+ """Returns True if the entity may have been updated since stream_pos
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos < self._earliest_known_stream_pos:
+ cache_counter.inc_misses(self.name)
+ return True
+
+ latest_entity_change_pos = self._entity_to_key.get(entity, None)
+ if latest_entity_change_pos is None:
+ cache_counter.inc_hits(self.name)
+ return False
+
+ if stream_pos < latest_entity_change_pos:
+ cache_counter.inc_misses(self.name)
+ return True
+
+ cache_counter.inc_hits(self.name)
+ return False
+
+ def get_entities_changed(self, entities, stream_pos):
+ """Returns subset of entities that have had new things since the
+ given position. If the position is too old it will just return the given list.
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos >= self._earliest_known_stream_pos:
+ keys = self._cache.keys()
+ i = keys.bisect_right(stream_pos)
+
+ result = set(
+ self._cache[k] for k in keys[i:]
+ ).intersection(entities)
+
+ cache_counter.inc_hits(self.name)
+ else:
+ result = entities
+ cache_counter.inc_misses(self.name)
+
+ return result
+
+ def entity_has_changed(self, entity, stream_pos):
+ """Informs the cache that the entity has been changed at the given
+ position.
+ """
+ assert type(stream_pos) is int
+
+ if stream_pos > self._earliest_known_stream_pos:
+ old_pos = self._entity_to_key.get(entity, None)
+ if old_pos is not None:
+ stream_pos = max(stream_pos, old_pos)
+ self._cache.pop(old_pos, None)
+ self._cache[stream_pos] = entity
+ self._entity_to_key[entity] = stream_pos
+
+ while len(self._cache) > self._max_size:
+ k, r = self._cache.popitem()
+ self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
+ self._entity_to_key.pop(r, None)
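
Usage sketch for the new StreamChangeCache (not part of the patch; it needs the blist package for sorteddict, and the entity ids are invented):

    from synapse.util.caches.stream_change_cache import StreamChangeCache

    cache = StreamChangeCache("ExampleCache", current_stream_pos=1)
    cache.entity_has_changed("@alice:hs", 3)
    cache.entity_has_changed("@bob:hs", 5)

    cache.has_entity_changed("@alice:hs", 2)   # True:  a change happened at 3
    cache.has_entity_changed("@alice:hs", 4)   # False: no change after 4
    cache.has_entity_changed("@carol:hs", 0)   # True:  0 predates the cache

    # Only @bob:hs has a change strictly after position 4.
    cache.get_entities_changed(["@alice:hs", "@bob:hs"], 4)   # {"@bob:hs"}
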
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
new file mode 100644
index 0000000000..03bc1401b7
--- /dev/null
+++ b/synapse/util/caches/treecache.py
@@ -0,0 +1,92 @@
+SENTINEL = object()
+
+
+class TreeCache(object):
+ """
+ Tree-based backing store for LruCache. Allows subtrees of data to be deleted
+ efficiently.
+ Keys must be tuples.
+ """
+ def __init__(self):
+ self.size = 0
+ self.root = {}
+
+ def __setitem__(self, key, value):
+ return self.set(key, value)
+
+ def __contains__(self, key):
+ return self.get(key, SENTINEL) is not SENTINEL
+
+ def set(self, key, value):
+ node = self.root
+ for k in key[:-1]:
+ node = node.setdefault(k, {})
+ node[key[-1]] = _Entry(value)
+ self.size += 1
+
+ def get(self, key, default=None):
+ node = self.root
+ for k in key[:-1]:
+ node = node.get(k, None)
+ if node is None:
+ return default
+ return node.get(key[-1], _Entry(default)).value
+
+ def clear(self):
+ self.size = 0
+ self.root = {}
+
+ def pop(self, key, default=None):
+ nodes = []
+
+ node = self.root
+ for k in key[:-1]:
+ node = node.get(k, None)
+ nodes.append(node) # never appends the root node (node was advanced past it above)
+ if node is None:
+ return default
+ popped = node.pop(key[-1], SENTINEL)
+ if popped is SENTINEL:
+ return default
+
+ node_and_keys = zip(nodes, key)
+ node_and_keys.reverse()
+ node_and_keys.append((self.root, None))
+
+ for i in range(len(node_and_keys) - 1):
+ n, k = node_and_keys[i]
+
+ if n:
+ break
+ node_and_keys[i + 1][0].pop(k)
+
+ popped, cnt = _strip_and_count_entries(popped)
+ self.size -= cnt
+ return popped
+
+ def __len__(self):
+ return self.size
+
+
+class _Entry(object):
+ __slots__ = ["value"]
+
+ def __init__(self, value):
+ self.value = value
+
+
+def _strip_and_count_entries(d):
+ """Takes an _Entry or a dict with leaves of _Entry objects, and either returns
+ the value or a dictionary with the _Entry objects replaced by their values.
+
+ Also returns the count of _Entry objects found.
+ """
+ if isinstance(d, dict):
+ cnt = 0
+ for key, value in d.items():
+ v, n = _strip_and_count_entries(value)
+ d[key] = v
+ cnt += n
+ return d, cnt
+ else:
+ return d.value, 1
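
Usage sketch for TreeCache (not part of the patch): keys are tuples, and pop() accepts either a full key or a key prefix; popping a prefix removes the whole subtree and returns it with the _Entry wrappers stripped.

    from synapse.util.caches.treecache import TreeCache

    t = TreeCache()
    t[("!room:hs", "@alice:hs")] = "m.read"
    t[("!room:hs", "@bob:hs")] = "m.read"
    t[("!other:hs", "@alice:hs")] = "m.read"
    assert len(t) == 3

    t.pop(("!room:hs", "@bob:hs"))     # pop a single full key
    subtree = t.pop(("!room:hs",))     # pop everything left under the room
    assert subtree == {"@alice:hs": "m.read"}
    assert len(t) == 1
    assert ("!other:hs", "@alice:hs") in t
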