summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/descriptors.py64
-rw-r--r--synapse/util/caches/dictionary_cache.py6
-rw-r--r--synapse/util/caches/lrucache.py15
-rw-r--r--synapse/util/caches/response_cache.py106
4 files changed, 159 insertions, 32 deletions
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index bf3a66eae4..68285a7594 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -39,12 +40,11 @@ _CacheSentinel = object()
 
 class CacheEntry(object):
     __slots__ = [
-        "deferred", "sequence", "callbacks", "invalidated"
+        "deferred", "callbacks", "invalidated"
     ]
 
-    def __init__(self, deferred, sequence, callbacks):
+    def __init__(self, deferred, callbacks):
         self.deferred = deferred
-        self.sequence = sequence
         self.callbacks = set(callbacks)
         self.invalidated = False
 
@@ -62,7 +62,6 @@ class Cache(object):
         "max_entries",
         "name",
         "keylen",
-        "sequence",
         "thread",
         "metrics",
         "_pending_deferred_cache",
@@ -80,7 +79,6 @@ class Cache(object):
 
         self.name = name
         self.keylen = keylen
-        self.sequence = 0
         self.thread = None
         self.metrics = register_cache(name, self.cache)
 
@@ -113,11 +111,10 @@ class Cache(object):
         callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _CacheSentinel)
         if val is not _CacheSentinel:
-            if val.sequence == self.sequence:
-                val.callbacks.update(callbacks)
-                if update_metrics:
-                    self.metrics.inc_hits()
-                return val.deferred
+            val.callbacks.update(callbacks)
+            if update_metrics:
+                self.metrics.inc_hits()
+            return val.deferred
 
         val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
         if val is not _CacheSentinel:
@@ -137,12 +134,9 @@ class Cache(object):
         self.check_thread()
         entry = CacheEntry(
             deferred=value,
-            sequence=self.sequence,
             callbacks=callbacks,
         )
 
-        entry.callbacks.update(callbacks)
-
         existing_entry = self._pending_deferred_cache.pop(key, None)
         if existing_entry:
             existing_entry.invalidate()
@@ -150,13 +144,25 @@ class Cache(object):
         self._pending_deferred_cache[key] = entry
 
         def shuffle(result):
-            if self.sequence == entry.sequence:
-                existing_entry = self._pending_deferred_cache.pop(key, None)
-                if existing_entry is entry:
-                    self.cache.set(key, result, entry.callbacks)
-                else:
-                    entry.invalidate()
+            existing_entry = self._pending_deferred_cache.pop(key, None)
+            if existing_entry is entry:
+                self.cache.set(key, result, entry.callbacks)
             else:
+                # oops, the _pending_deferred_cache has been updated since
+                # we started our query, so we are out of date.
+                #
+                # Better put back whatever we took out. (We do it this way
+                # round, rather than peeking into the _pending_deferred_cache
+                # and then removing on a match, to make the common case faster)
+                if existing_entry is not None:
+                    self._pending_deferred_cache[key] = existing_entry
+
+                # we're not going to put this entry into the cache, so need
+                # to make sure that the invalidation callbacks are called.
+                # That was probably done when _pending_deferred_cache was
+                # updated, but it's possible that `set` was called without
+                # `invalidate` being previously called, in which case it may
+                # not have been. Either way, let's double-check now.
                 entry.invalidate()
             return result
 
@@ -168,25 +174,29 @@ class Cache(object):
 
     def invalidate(self, key):
         self.check_thread()
+        self.cache.pop(key, None)
 
-        # Increment the sequence number so that any SELECT statements that
-        # raced with the INSERT don't update the cache (SYN-369)
-        self.sequence += 1
+        # if we have a pending lookup for this key, remove it from the
+        # _pending_deferred_cache, which will (a) stop it being returned
+        # for future queries and (b) stop it being persisted as a proper entry
+        # in self.cache.
         entry = self._pending_deferred_cache.pop(key, None)
+
+        # run the invalidation callbacks now, rather than waiting for the
+        # deferred to resolve.
         if entry:
             entry.invalidate()
 
-        self.cache.pop(key, None)
-
     def invalidate_many(self, key):
         self.check_thread()
         if not isinstance(key, tuple):
             raise TypeError(
                 "The cache key must be a tuple not %r" % (type(key),)
             )
-        self.sequence += 1
         self.cache.del_multi(key)
 
+        # if we have a pending lookup for this key, remove it from the
+        # _pending_deferred_cache, as above
         entry_dict = self._pending_deferred_cache.pop(key, None)
         if entry_dict is not None:
             for entry in iterate_tree_cache_entry(entry_dict):
@@ -194,8 +204,10 @@ class Cache(object):
 
     def invalidate_all(self):
         self.check_thread()
-        self.sequence += 1
         self.cache.clear()
+        for entry in self._pending_deferred_cache.itervalues():
+            entry.invalidate()
+        self._pending_deferred_cache.clear()
 
 
 class _CacheDescriptorBase(object):
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index d4105822b3..1709e8b429 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -132,9 +132,13 @@ class DictionaryCache(object):
                 self._update_or_insert(key, value, known_absent)
 
     def _update_or_insert(self, key, value, known_absent):
-        entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {}))
+        # We pop and reinsert as we need to tell the cache the size may have
+        # changed
+
+        entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
         entry.value.update(value)
         entry.known_absent.update(known_absent)
+        self.cache[key] = entry
 
     def _insert(self, key, value, known_absent):
         self.cache[key] = DictionaryEntry(True, known_absent, value)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index f088dd430e..1c5a982094 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -154,14 +154,21 @@ class LruCache(object):
         def cache_set(key, value, callbacks=[]):
             node = cache.get(key, None)
             if node is not None:
-                if value != node.value:
+                # We sometimes store large objects, e.g. dicts, which cause
+                # the inequality check to take a long time. So let's only do
+                # the check if we have some callbacks to call.
+                if node.callbacks and value != node.value:
                     for cb in node.callbacks:
                         cb()
                     node.callbacks.clear()
 
-                    if size_callback:
-                        cached_cache_len[0] -= size_callback(node.value)
-                        cached_cache_len[0] += size_callback(value)
+                # We don't bother to protect this by value != node.value as
+                # generally size_callback will be cheap compared with equality
+                # checks. (For example, taking the size of two dicts is quicker
+                # than comparing them for equality.)
+                if size_callback:
+                    cached_cache_len[0] -= size_callback(node.value)
+                    cached_cache_len[0] += size_callback(value)
 
                 node.callbacks.update(callbacks)
 
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 00af539880..7f79333e96 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,8 +12,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
+from twisted.internet import defer
 
 from synapse.util.async import ObservableDeferred
+from synapse.util.caches import metrics as cache_metrics
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
+logger = logging.getLogger(__name__)
 
 
 class ResponseCache(object):
@@ -24,20 +31,68 @@ class ResponseCache(object):
     used rather than trying to compute a new response.
     """
 
-    def __init__(self, hs, timeout_ms=0):
+    def __init__(self, hs, name, timeout_ms=0):
         self.pending_result_cache = {}  # Requests that haven't finished yet.
 
         self.clock = hs.get_clock()
         self.timeout_sec = timeout_ms / 1000.
 
+        self._name = name
+        self._metrics = cache_metrics.register_cache(
+            "response_cache",
+            size_callback=lambda: self.size(),
+            cache_name=name,
+        )
+
+    def size(self):
+        return len(self.pending_result_cache)
+
     def get(self, key):
+        """Look up the given key.
+
+        Can return either a new Deferred (which also doesn't follow the synapse
+        logcontext rules), or, if the request has completed, the actual
+        result. You will probably want to make_deferred_yieldable the result.
+
+        If there is no entry for the key, returns None. It is worth noting that
+        this means there is no way to distinguish a completed result of None
+        from an absent cache entry.
+
+        Args:
+            key (hashable):
+
+        Returns:
+            twisted.internet.defer.Deferred|None|E: None if there is no entry
+            for this key; otherwise either a deferred result or the result
+            itself.
+        """
         result = self.pending_result_cache.get(key)
         if result is not None:
+            self._metrics.inc_hits()
             return result.observe()
         else:
+            self._metrics.inc_misses()
             return None
 
     def set(self, key, deferred):
+        """Set the entry for the given key to the given deferred.
+
+        *deferred* should run its callbacks in the sentinel logcontext (ie,
+        you should wrap normal synapse deferreds with
+        logcontext.run_in_background).
+
+        Can return either a new Deferred (which also doesn't follow the synapse
+        logcontext rules), or, if *deferred* was already complete, the actual
+        result. You will probably want to make_deferred_yieldable the result.
+
+        Args:
+            key (hashable):
+            deferred (twisted.internet.defer.Deferred[T):
+
+        Returns:
+            twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
+                result.
+        """
         result = ObservableDeferred(deferred, consumeErrors=True)
         self.pending_result_cache[key] = result
 
@@ -53,3 +108,52 @@ class ResponseCache(object):
 
         result.addBoth(remove)
         return result.observe()
+
+    def wrap(self, key, callback, *args, **kwargs):
+        """Wrap together a *get* and *set* call, taking care of logcontexts
+
+        First looks up the key in the cache, and if it is present makes it
+        follow the synapse logcontext rules and returns it.
+
+        Otherwise, makes a call to *callback(*args, **kwargs)*, which should
+        follow the synapse logcontext rules, and adds the result to the cache.
+
+        Example usage:
+
+            @defer.inlineCallbacks
+            def handle_request(request):
+                # etc
+                defer.returnValue(result)
+
+            result = yield response_cache.wrap(
+                key,
+                handle_request,
+                request,
+            )
+
+        Args:
+            key (hashable): key to get/set in the cache
+
+            callback (callable): function to call if the key is not found in
+                the cache
+
+            *args: positional parameters to pass to the callback, if it is used
+
+            **kwargs: named paramters to pass to the callback, if it is used
+
+        Returns:
+            twisted.internet.defer.Deferred: yieldable result
+        """
+        result = self.get(key)
+        if not result:
+            logger.info("[%s]: no cached result for [%s], calculating new one",
+                        self._name, key)
+            d = run_in_background(callback, *args, **kwargs)
+            result = self.set(key, d)
+        elif not isinstance(result, defer.Deferred) or result.called:
+            logger.info("[%s]: using completed cached result for [%s]",
+                        self._name, key)
+        else:
+            logger.info("[%s]: using incomplete cached result for [%s]",
+                        self._name, key)
+        return make_deferred_yieldable(result)