summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/descriptors.py26
-rw-r--r--synapse/util/caches/expiringcache.py20
-rw-r--r--synapse/util/caches/lrucache.py10
-rw-r--r--synapse/util/caches/snapshot_cache.py3
-rw-r--r--synapse/util/caches/stream_change_cache.py123
-rw-r--r--synapse/util/caches/treecache.py38
6 files changed, 199 insertions, 21 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 88e56e3302..35544b19fd 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
 
 from . import caches_by_name, DEBUG_CACHES, cache_counter
 
@@ -25,6 +28,7 @@ from twisted.internet import defer
 
 from collections import OrderedDict
 
+import os
 import functools
 import inspect
 import threading
@@ -35,6 +39,9 @@ logger = logging.getLogger(__name__)
 _CacheSentinel = object()
 
 
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+
+
 class Cache(object):
 
     def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
@@ -137,6 +144,8 @@ class CacheDescriptor(object):
     """
     def __init__(self, orig, max_entries=1000, num_args=1, lru=True, tree=False,
                  inlineCallbacks=False):
+        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
+
         self.orig = orig
 
         if inlineCallbacks:
@@ -149,7 +158,7 @@ class CacheDescriptor(object):
         self.lru = lru
         self.tree = tree
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
 
         if len(self.arg_names) < self.num_args:
             raise Exception(
@@ -190,7 +199,7 @@ class CacheDescriptor(object):
                         defer.returnValue(cached_result)
                     observer.addCallback(check_result)
 
-                return observer
+                return preserve_context_over_deferred(observer)
             except KeyError:
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
@@ -198,6 +207,7 @@ class CacheDescriptor(object):
                 sequence = self.cache.sequence
 
                 ret = defer.maybeDeferred(
+                    preserve_context_over_fn,
                     self.function_to_call,
                     obj, *args, **kwargs
                 )
@@ -211,7 +221,7 @@ class CacheDescriptor(object):
                 ret = ObservableDeferred(ret, consumeErrors=True)
                 self.cache.update(sequence, cache_key, ret)
 
-                return ret.observe()
+                return preserve_context_over_deferred(ret.observe())
 
         wrapped.invalidate = self.cache.invalidate
         wrapped.invalidate_all = self.cache.invalidate_all
@@ -250,7 +260,7 @@ class CacheListDescriptor(object):
         self.num_args = num_args
         self.list_name = list_name
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
+        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
 
         self.cache = cache
@@ -299,6 +309,7 @@ class CacheListDescriptor(object):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
+                    preserve_context_over_fn,
                     self.function_to_call,
                     **args_to_call
                 )
@@ -308,7 +319,8 @@ class CacheListDescriptor(object):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    observer = ret_d.observe()
+                    with PreserveLoggingContext():
+                        observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
@@ -327,10 +339,10 @@ class CacheListDescriptor(object):
 
                     cached[arg] = res
 
-            return defer.gatherResults(
+            return preserve_context_over_deferred(defer.gatherResults(
                 cached.values(),
                 consumeErrors=True,
-            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
 
         obj.__dict__[self.orig.__name__] = wrapped
 
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 494226f5ea..2b68c1ac93 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.util.caches import cache_counter, caches_by_name
+
 import logging
 
 
@@ -47,6 +49,8 @@ class ExpiringCache(object):
 
         self._cache = {}
 
+        caches_by_name[cache_name] = self._cache
+
     def start(self):
         if not self._expiry_ms:
             # Don't bother starting the loop if things never expire
@@ -55,7 +59,7 @@ class ExpiringCache(object):
         def f():
             self._prune_cache()
 
-        self._clock.looping_call(f, self._expiry_ms/2)
+        self._clock.looping_call(f, self._expiry_ms / 2)
 
     def __setitem__(self, key, value):
         now = self._clock.time_msec()
@@ -65,14 +69,19 @@ class ExpiringCache(object):
         if self._max_len and len(self._cache.keys()) > self._max_len:
             sorted_entries = sorted(
                 self._cache.items(),
-                key=lambda (k, v): v.time,
+                key=lambda item: item[1].time,
             )
 
             for k, _ in sorted_entries[self._max_len:]:
                 self._cache.pop(k)
 
     def __getitem__(self, key):
-        entry = self._cache[key]
+        try:
+            entry = self._cache[key]
+            cache_counter.inc_hits(self._cache_name)
+        except KeyError:
+            cache_counter.inc_misses(self._cache_name)
+            raise
 
         if self._reset_expiry_on_get:
             entry.time = self._clock.time_msec()
@@ -105,9 +114,12 @@ class ExpiringCache(object):
 
         logger.debug(
             "[%s] _prune_cache before: %d, after len: %d",
-            self._cache_name, begin_length, len(self._cache.keys())
+            self._cache_name, begin_length, len(self._cache)
         )
 
+    def __len__(self):
+        return len(self._cache)
+
 
 class _CacheEntry(object):
     def __init__(self, time, value):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index e6a66dc041..f7423f2fab 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -37,7 +37,7 @@ class LruCache(object):
     """
     def __init__(self, max_size, keylen=1, cache_type=dict):
         cache = cache_type()
-        self.size = 0
+        self.cache = cache  # Used for introspection.
         list_root = []
         list_root[:] = [list_root, list_root, None, None]
 
@@ -60,7 +60,6 @@ class LruCache(object):
             prev_node[NEXT] = node
             next_node[PREV] = node
             cache[key] = node
-            self.size += 1
 
         def move_node_to_front(node):
             prev_node = node[PREV]
@@ -79,7 +78,6 @@ class LruCache(object):
             next_node = node[NEXT]
             prev_node[NEXT] = next_node
             next_node[PREV] = prev_node
-            self.size -= 1
 
         @synchronized
         def cache_get(key, default=None):
@@ -98,7 +96,7 @@ class LruCache(object):
                 node[VALUE] = value
             else:
                 add_node(key, value)
-                if self.size > max_size:
+                if len(cache) > max_size:
                     todelete = list_root[PREV]
                     delete_node(todelete)
                     cache.pop(todelete[KEY], None)
@@ -110,7 +108,7 @@ class LruCache(object):
                 return node[VALUE]
             else:
                 add_node(key, value)
-                if self.size > max_size:
+                if len(cache) > max_size:
                     todelete = list_root[PREV]
                     delete_node(todelete)
                     cache.pop(todelete[KEY], None)
@@ -145,7 +143,7 @@ class LruCache(object):
 
         @synchronized
         def cache_len():
-            return self.size
+            return len(cache)
 
         @synchronized
         def cache_contains(key):
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index b1e40417fd..d03678b8c8 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -87,7 +87,8 @@ class SnapshotCache(object):
             # expire from the rotation of that cache.
             self.next_result_cache[key] = result
             self.pending_result_cache.pop(key, None)
+            return r
 
-        result.observe().addBoth(shuffle_along)
+        result.addBoth(shuffle_along)
 
         return result.observe()
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
new file mode 100644
index 0000000000..ea8a74ca69
--- /dev/null
+++ b/synapse/util/caches/stream_change_cache.py
@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.caches import cache_counter, caches_by_name
+
+
+from blist import sorteddict
+import logging
+import os
+
+
+logger = logging.getLogger(__name__)
+
+
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+
+
+class StreamChangeCache(object):
+    """Keeps track of the stream positions of the latest change in a set of entities.
+
+    Typically the entity will be a room or user id.
+
+    Given a list of entities and a stream position, it will give a subset of
+    entities that may have changed since that position. If position key is too
+    old then the cache will simply return all given entities.
+    """
+    def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
+        self._max_size = int(max_size * CACHE_SIZE_FACTOR)
+        self._entity_to_key = {}
+        self._cache = sorteddict()
+        self._earliest_known_stream_pos = current_stream_pos
+        self.name = name
+        caches_by_name[self.name] = self._cache
+
+        for entity, stream_pos in prefilled_cache.items():
+            self.entity_has_changed(entity, stream_pos)
+
+    def has_entity_changed(self, entity, stream_pos):
+        """Returns True if the entity may have been updated since stream_pos
+        """
+        assert type(stream_pos) is int
+
+        if stream_pos < self._earliest_known_stream_pos:
+            cache_counter.inc_misses(self.name)
+            return True
+
+        latest_entity_change_pos = self._entity_to_key.get(entity, None)
+        if latest_entity_change_pos is None:
+            cache_counter.inc_hits(self.name)
+            return False
+
+        if stream_pos < latest_entity_change_pos:
+            cache_counter.inc_misses(self.name)
+            return True
+
+        cache_counter.inc_hits(self.name)
+        return False
+
+    def get_entities_changed(self, entities, stream_pos):
+        """Returns subset of entities that have had new things since the
+        given position. If the position is too old it will just return the given list.
+        """
+        assert type(stream_pos) is int
+
+        if stream_pos >= self._earliest_known_stream_pos:
+            keys = self._cache.keys()
+            i = keys.bisect_right(stream_pos)
+
+            result = set(
+                self._cache[k] for k in keys[i:]
+            ).intersection(entities)
+
+            cache_counter.inc_hits(self.name)
+        else:
+            result = entities
+            cache_counter.inc_misses(self.name)
+
+        return result
+
+    def get_all_entities_changed(self, stream_pos):
+        """Returns all entites that have had new things since the given
+        position. If the position is too old it will return None.
+        """
+        assert type(stream_pos) is int
+
+        if stream_pos >= self._earliest_known_stream_pos:
+            keys = self._cache.keys()
+            i = keys.bisect_right(stream_pos)
+
+            return [self._cache[k] for k in keys[i:]]
+        else:
+            return None
+
+    def entity_has_changed(self, entity, stream_pos):
+        """Informs the cache that the entity has been changed at the given
+        position.
+        """
+        assert type(stream_pos) is int
+
+        if stream_pos > self._earliest_known_stream_pos:
+            old_pos = self._entity_to_key.get(entity, None)
+            if old_pos is not None:
+                stream_pos = max(stream_pos, old_pos)
+                self._cache.pop(old_pos, None)
+            self._cache[stream_pos] = entity
+            self._entity_to_key[entity] = stream_pos
+
+            while len(self._cache) > self._max_size:
+                k, r = self._cache.popitem()
+                self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
+                self._entity_to_key.pop(r, None)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index 3b58860910..03bc1401b7 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -8,6 +8,7 @@ class TreeCache(object):
     Keys must be tuples.
     """
     def __init__(self):
+        self.size = 0
         self.root = {}
 
     def __setitem__(self, key, value):
@@ -20,7 +21,8 @@ class TreeCache(object):
         node = self.root
         for k in key[:-1]:
             node = node.setdefault(k, {})
-        node[key[-1]] = value
+        node[key[-1]] = _Entry(value)
+        self.size += 1
 
     def get(self, key, default=None):
         node = self.root
@@ -28,9 +30,10 @@ class TreeCache(object):
             node = node.get(k, None)
             if node is None:
                 return default
-        return node.get(key[-1], default)
+        return node.get(key[-1], _Entry(default)).value
 
     def clear(self):
+        self.size = 0
         self.root = {}
 
     def pop(self, key, default=None):
@@ -55,6 +58,35 @@ class TreeCache(object):
 
             if n:
                 break
-            node_and_keys[i+1][0].pop(k)
+            node_and_keys[i + 1][0].pop(k)
 
+        popped, cnt = _strip_and_count_entires(popped)
+        self.size -= cnt
         return popped
+
+    def __len__(self):
+        return self.size
+
+
+class _Entry(object):
+    __slots__ = ["value"]
+
+    def __init__(self, value):
+        self.value = value
+
+
+def _strip_and_count_entires(d):
+    """Takes an _Entry or dict with leaves of _Entry's, and either returns the
+    value or a dictionary with _Entry's replaced by their values.
+
+    Also returns the count of _Entry's
+    """
+    if isinstance(d, dict):
+        cnt = 0
+        for key, value in d.items():
+            v, n = _strip_and_count_entires(value)
+            d[key] = v
+            cnt += n
+        return d, cnt
+    else:
+        return d.value, 1