summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/__init__.py20
-rw-r--r--synapse/util/caches/descriptors.py93
-rw-r--r--synapse/util/caches/dictionary_cache.py8
-rw-r--r--synapse/util/caches/expiringcache.py8
-rw-r--r--synapse/util/caches/response_cache.py46
-rw-r--r--synapse/util/caches/stream_change_cache.py16
6 files changed, 137 insertions, 54 deletions
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index d53569ca49..ebd715c5dc 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -24,11 +24,21 @@ DEBUG_CACHES = False
 metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
 
 caches_by_name = {}
-cache_counter = metrics.register_cache(
-    "cache",
-    lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
-    labels=["name"],
-)
+# cache_counter = metrics.register_cache(
+#     "cache",
+#     lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
+#     labels=["name"],
+# )
+
+
+def register_cache(name, cache):
+    caches_by_name[name] = cache
+    return metrics.register_cache(
+        "cache",
+        lambda: len(cache),
+        name,
+    )
+
 
 _string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
 caches_by_name["string_cache"] = _string_cache
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 35544b19fd..f31dfb22b7 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -22,7 +22,7 @@ from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
 )
 
-from . import caches_by_name, DEBUG_CACHES, cache_counter
+from . import DEBUG_CACHES, register_cache
 
 from twisted.internet import defer
 
@@ -33,6 +33,7 @@ import functools
 import inspect
 import threading
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -43,6 +44,15 @@ CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
 class Cache(object):
+    __slots__ = (
+        "cache",
+        "max_entries",
+        "name",
+        "keylen",
+        "sequence",
+        "thread",
+        "metrics",
+    )
 
     def __init__(self, name, max_entries=1000, keylen=1, lru=True, tree=False):
         if lru:
@@ -59,7 +69,7 @@ class Cache(object):
         self.keylen = keylen
         self.sequence = 0
         self.thread = None
-        caches_by_name[name] = self.cache
+        self.metrics = register_cache(name, self.cache)
 
     def check_thread(self):
         expected_thread = self.thread
@@ -74,10 +84,10 @@ class Cache(object):
     def get(self, key, default=_CacheSentinel):
         val = self.cache.get(key, _CacheSentinel)
         if val is not _CacheSentinel:
-            cache_counter.inc_hits(self.name)
+            self.metrics.inc_hits()
             return val
 
-        cache_counter.inc_misses(self.name)
+        self.metrics.inc_misses()
 
         if default is _CacheSentinel:
             raise KeyError()
@@ -167,7 +177,8 @@ class CacheDescriptor(object):
                 % (orig.__name__,)
             )
 
-        self.cache = Cache(
+    def __get__(self, obj, objtype=None):
+        cache = Cache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
             keylen=self.num_args,
@@ -175,14 +186,12 @@ class CacheDescriptor(object):
             tree=self.tree,
         )
 
-    def __get__(self, obj, objtype=None):
-
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
             try:
-                cached_result_d = self.cache.get(cache_key)
+                cached_result_d = cache.get(cache_key)
 
                 observer = cached_result_d.observe()
                 if DEBUG_CACHES:
@@ -204,7 +213,7 @@ class CacheDescriptor(object):
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
                 # while the SELECT is executing (SYN-369)
-                sequence = self.cache.sequence
+                sequence = cache.sequence
 
                 ret = defer.maybeDeferred(
                     preserve_context_over_fn,
@@ -213,20 +222,21 @@ class CacheDescriptor(object):
                 )
 
                 def onErr(f):
-                    self.cache.invalidate(cache_key)
+                    cache.invalidate(cache_key)
                     return f
 
                 ret.addErrback(onErr)
 
                 ret = ObservableDeferred(ret, consumeErrors=True)
-                self.cache.update(sequence, cache_key, ret)
+                cache.update(sequence, cache_key, ret)
 
                 return preserve_context_over_deferred(ret.observe())
 
-        wrapped.invalidate = self.cache.invalidate
-        wrapped.invalidate_all = self.cache.invalidate_all
-        wrapped.invalidate_many = self.cache.invalidate_many
-        wrapped.prefill = self.cache.prefill
+        wrapped.invalidate = cache.invalidate
+        wrapped.invalidate_all = cache.invalidate_all
+        wrapped.invalidate_many = cache.invalidate_many
+        wrapped.prefill = cache.prefill
+        wrapped.cache = cache
 
         obj.__dict__[self.orig.__name__] = wrapped
 
@@ -240,11 +250,12 @@ class CacheListDescriptor(object):
     the list of missing keys to the wrapped fucntion.
     """
 
-    def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
+    def __init__(self, orig, cached_method_name, list_name, num_args=1,
+                 inlineCallbacks=False):
         """
         Args:
             orig (function)
-            cache (Cache)
+            method_name (str); The name of the chached method.
             list_name (str): Name of the argument which is the bulk lookup list
             num_args (int)
             inlineCallbacks (bool): Whether orig is a generator that should
@@ -263,7 +274,7 @@ class CacheListDescriptor(object):
         self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
 
-        self.cache = cache
+        self.cached_method_name = cached_method_name
 
         self.sentinel = object()
 
@@ -277,11 +288,13 @@ class CacheListDescriptor(object):
         if self.list_name not in self.arg_names:
             raise Exception(
                 "Couldn't see arguments %r for %r."
-                % (self.list_name, cache.name,)
+                % (self.list_name, cached_method_name,)
             )
 
     def __get__(self, obj, objtype=None):
 
+        cache = getattr(obj, self.cached_method_name).cache
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
@@ -290,21 +303,26 @@ class CacheListDescriptor(object):
 
             # cached is a dict arg -> deferred, where deferred results in a
             # 2-tuple (`arg`, `result`)
-            cached = {}
+            results = {}
+            cached_defers = {}
             missing = []
             for arg in list_args:
                 key = list(keyargs)
                 key[self.list_pos] = arg
 
                 try:
-                    res = self.cache.get(tuple(key)).observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-                    cached[arg] = res
+                    res = cache.get(tuple(key))
+                    if not res.has_succeeded():
+                        res = res.observe()
+                        res.addCallback(lambda r, arg: (arg, r), arg)
+                        cached_defers[arg] = res
+                    else:
+                        results[arg] = res.get_result()
                 except KeyError:
                     missing.append(arg)
 
             if missing:
-                sequence = self.cache.sequence
+                sequence = cache.sequence
                 args_to_call = dict(arg_dict)
                 args_to_call[self.list_name] = missing
 
@@ -327,22 +345,31 @@ class CacheListDescriptor(object):
 
                     key = list(keyargs)
                     key[self.list_pos] = arg
-                    self.cache.update(sequence, tuple(key), observer)
+                    cache.update(sequence, tuple(key), observer)
 
                     def invalidate(f, key):
-                        self.cache.invalidate(key)
+                        cache.invalidate(key)
                         return f
                     observer.addErrback(invalidate, tuple(key))
 
                     res = observer.observe()
                     res.addCallback(lambda r, arg: (arg, r), arg)
 
-                    cached[arg] = res
+                    cached_defers[arg] = res
+
+            if cached_defers:
+                def update_results_dict(res):
+                    results.update(res)
+                    return results
 
-            return preserve_context_over_deferred(defer.gatherResults(
-                cached.values(),
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
+                return preserve_context_over_deferred(defer.gatherResults(
+                    cached_defers.values(),
+                    consumeErrors=True,
+                ).addCallback(update_results_dict).addErrback(
+                    unwrapFirstError
+                ))
+            else:
+                return results
 
         obj.__dict__[self.orig.__name__] = wrapped
 
@@ -370,7 +397,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False, tree=False):
     )
 
 
-def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
     """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. A single argument
@@ -400,7 +427,7 @@ def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
     """
     return lambda orig: CacheListDescriptor(
         orig,
-        cache=cache,
+        cached_method_name=cached_method_name,
         list_name=list_name,
         num_args=num_args,
         inlineCallbacks=inlineCallbacks,
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index f92d80542b..b0ca1bb79d 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -15,7 +15,7 @@
 
 from synapse.util.caches.lrucache import LruCache
 from collections import namedtuple
-from . import caches_by_name, cache_counter
+from . import register_cache
 import threading
 import logging
 
@@ -43,7 +43,7 @@ class DictionaryCache(object):
             __slots__ = []
 
         self.sentinel = Sentinel()
-        caches_by_name[name] = self.cache
+        self.metrics = register_cache(name, self.cache)
 
     def check_thread(self):
         expected_thread = self.thread
@@ -58,7 +58,7 @@ class DictionaryCache(object):
     def get(self, key, dict_keys=None):
         entry = self.cache.get(key, self.sentinel)
         if entry is not self.sentinel:
-            cache_counter.inc_hits(self.name)
+            self.metrics.inc_hits()
 
             if dict_keys is None:
                 return DictionaryEntry(entry.full, dict(entry.value))
@@ -69,7 +69,7 @@ class DictionaryCache(object):
                     if k in entry.value
                 })
 
-        cache_counter.inc_misses(self.name)
+        self.metrics.inc_misses()
         return DictionaryEntry(False, {})
 
     def invalidate(self, key):
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2b68c1ac93..080388958f 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache
 
 import logging
 
@@ -49,7 +49,7 @@ class ExpiringCache(object):
 
         self._cache = {}
 
-        caches_by_name[cache_name] = self._cache
+        self.metrics = register_cache(cache_name, self._cache)
 
     def start(self):
         if not self._expiry_ms:
@@ -78,9 +78,9 @@ class ExpiringCache(object):
     def __getitem__(self, key):
         try:
             entry = self._cache[key]
-            cache_counter.inc_hits(self._cache_name)
+            self.metrics.inc_hits()
         except KeyError:
-            cache_counter.inc_misses(self._cache_name)
+            self.metrics.inc_misses()
             raise
 
         if self._reset_expiry_on_get:
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
new file mode 100644
index 0000000000..36686b479e
--- /dev/null
+++ b/synapse/util/caches/response_cache.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.async import ObservableDeferred
+
+
+class ResponseCache(object):
+    """
+    This caches a deferred response. Until the deferred completes it will be
+    returned from the cache. This means that if the client retries the request
+    while the response is still being computed, that original response will be
+    used rather than trying to compute a new response.
+    """
+
+    def __init__(self):
+        self.pending_result_cache = {}  # Requests that haven't finished yet.
+
+    def get(self, key):
+        result = self.pending_result_cache.get(key)
+        if result is not None:
+            return result.observe()
+        else:
+            return None
+
+    def set(self, key, deferred):
+        result = ObservableDeferred(deferred, consumeErrors=True)
+        self.pending_result_cache[key] = result
+
+        def remove(r):
+            self.pending_result_cache.pop(key, None)
+            return r
+
+        result.addBoth(remove)
+        return result.observe()
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index ea8a74ca69..3c051dabc4 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches import register_cache
 
 
 from blist import sorteddict
@@ -42,7 +42,7 @@ class StreamChangeCache(object):
         self._cache = sorteddict()
         self._earliest_known_stream_pos = current_stream_pos
         self.name = name
-        caches_by_name[self.name] = self._cache
+        self.metrics = register_cache(self.name, self._cache)
 
         for entity, stream_pos in prefilled_cache.items():
             self.entity_has_changed(entity, stream_pos)
@@ -53,19 +53,19 @@ class StreamChangeCache(object):
         assert type(stream_pos) is int
 
         if stream_pos < self._earliest_known_stream_pos:
-            cache_counter.inc_misses(self.name)
+            self.metrics.inc_misses()
             return True
 
         latest_entity_change_pos = self._entity_to_key.get(entity, None)
         if latest_entity_change_pos is None:
-            cache_counter.inc_hits(self.name)
+            self.metrics.inc_hits()
             return False
 
         if stream_pos < latest_entity_change_pos:
-            cache_counter.inc_misses(self.name)
+            self.metrics.inc_misses()
             return True
 
-        cache_counter.inc_hits(self.name)
+        self.metrics.inc_hits()
         return False
 
     def get_entities_changed(self, entities, stream_pos):
@@ -82,10 +82,10 @@ class StreamChangeCache(object):
                 self._cache[k] for k in keys[i:]
             ).intersection(entities)
 
-            cache_counter.inc_hits(self.name)
+            self.metrics.inc_hits()
         else:
             result = entities
-            cache_counter.inc_misses(self.name)
+            self.metrics.inc_misses()
 
         return result