summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/__init__.py6
-rw-r--r--synapse/util/caches/descriptors.py155
-rw-r--r--synapse/util/caches/dictionary_cache.py34
-rw-r--r--synapse/util/caches/expiringcache.py11
-rw-r--r--synapse/util/caches/lrucache.py2
-rw-r--r--synapse/util/caches/response_cache.py2
-rw-r--r--synapse/util/caches/snapshot_cache.py2
-rw-r--r--synapse/util/caches/stream_change_cache.py20
8 files changed, 115 insertions, 117 deletions
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 900575eb3c..7b065b195e 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -13,12 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from prometheus_client.core import Gauge, REGISTRY, GaugeMetricFamily
-
 import os
 
-from six.moves import intern
 import six
+from six.moves import intern
+
+from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily
 
 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
 
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 65a1042de1..187510576a 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -13,10 +13,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import functools
+import inspect
 import logging
+import threading
+from collections import namedtuple
+
+import six
+from six import itervalues, string_types
+
+from twisted.internet import defer
 
-from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError, logcontext
+from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
@@ -24,17 +33,6 @@ from synapse.util.stringutils import to_ascii
 
 from . import register_cache
 
-from twisted.internet import defer
-from collections import namedtuple
-
-import functools
-import inspect
-import threading
-
-from six import string_types, itervalues
-import six
-
-
 logger = logging.getLogger(__name__)
 
 
@@ -475,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
-            # If we're passed a cache_context then we'll want to call its invalidate()
-            # whenever we are invalidated
+            # If we're passed a cache_context then we'll want to call its
+            # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            # cached is a dict arg -> deferred, where deferred results in a
-            # 2-tuple (`arg`, `result`)
             results = {}
-            cached_defers = {}
-            missing = []
+
+            def update_results_dict(res, arg):
+                results[arg] = res
+
+            # list of deferreds to wait for
+            cached_defers = []
+
+            missing = set()
 
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
-                def cache_get(arg):
-                    return cache.get(arg, callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    return arg
             else:
-                key = list(keyargs)
+                keylist = list(keyargs)
 
-                def cache_get(arg):
-                    key[self.list_pos] = arg
-                    return cache.get(tuple(key), callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    keylist[self.list_pos] = arg
+                    return tuple(keylist)
 
             for arg in list_args:
                 try:
-                    res = cache_get(arg)
-
+                    res = cache.get(arg_to_cache_key(arg),
+                                    callback=invalidate_callback)
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
                         res = res.observe()
-                        res.addCallback(lambda r, arg: (arg, r), arg)
-                        cached_defers[arg] = res
+                        res.addCallback(update_results_dict, arg)
+                        cached_defers.append(res)
                     else:
                         results[arg] = res.get_result()
                 except KeyError:
-                    missing.append(arg)
+                    missing.add(arg)
 
             if missing:
+                # we need an observable deferred for each entry in the list,
+                # which we put in the cache. Each deferred resolves with the
+                # relevant result for that key.
+                deferreds_map = {}
+                for arg in missing:
+                    deferred = defer.Deferred()
+                    deferreds_map[arg] = deferred
+                    key = arg_to_cache_key(arg)
+                    observable = ObservableDeferred(deferred)
+                    cache.set(key, observable, callback=invalidate_callback)
+
+                def complete_all(res):
+                    # the wrapped function has completed. It returns a
+                    # a dict. We can now resolve the observable deferreds in
+                    # the cache and update our own result map.
+                    for e in missing:
+                        val = res.get(e, None)
+                        deferreds_map[e].callback(val)
+                        results[e] = val
+
+                def errback(f):
+                    # the wrapped function has failed. Invalidate any cache
+                    # entries we're supposed to be populating, and fail
+                    # their deferreds.
+                    for e in missing:
+                        key = arg_to_cache_key(e)
+                        cache.invalidate(key)
+                        deferreds_map[e].errback(f)
+
+                    # return the failure, to propagate to our caller.
+                    return f
+
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = list(missing)
 
-                ret_d = defer.maybeDeferred(
+                cached_defers.append(defer.maybeDeferred(
                     logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
-                )
-
-                ret_d = ObservableDeferred(ret_d)
-
-                # We need to create deferreds for each arg in the list so that
-                # we can insert the new deferred into the cache.
-                for arg in missing:
-                    observer = ret_d.observe()
-                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)
-
-                    observer = ObservableDeferred(observer)
-
-                    if num_args == 1:
-                        cache.set(
-                            arg, observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, arg)
-                    else:
-                        key = list(keyargs)
-                        key[self.list_pos] = arg
-                        cache.set(
-                            tuple(key), observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, tuple(key))
-
-                    res = observer.observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-
-                    cached_defers[arg] = res
+                ).addCallbacks(complete_all, errback))
 
             if cached_defers:
-                def update_results_dict(res):
-                    results.update(res)
-                    return results
-
-                return logcontext.make_deferred_yieldable(defer.gatherResults(
-                    list(cached_defers.values()),
+                d = defer.gatherResults(
+                    cached_defers,
                     consumeErrors=True,
-                ).addCallback(update_results_dict).addErrback(
+                ).addCallbacks(
+                    lambda _: results,
                     unwrapFirstError
-                ))
+                )
+                return logcontext.make_deferred_yieldable(d)
             else:
                 return results
 
@@ -627,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
     cache.
 
     Args:
-        cache (Cache): The underlying cache to use.
+        cached_method_name (str): The name of the single-item lookup method.
+            This is only used to find the cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
         num_args (int): Number of arguments to use as the key in the cache
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index bdc21e348f..6c0b5a4094 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.caches.lrucache import LruCache
-from collections import namedtuple
-from . import register_cache
-import threading
 import logging
+import threading
+from collections import namedtuple
 
+from synapse.util.caches.lrucache import LruCache
+
+from . import register_cache
 
 logger = logging.getLogger(__name__)
 
@@ -107,29 +108,28 @@ class DictionaryCache(object):
         self.sequence += 1
         self.cache.clear()
 
-    def update(self, sequence, key, value, full=False, known_absent=None):
+    def update(self, sequence, key, value, fetched_keys=None):
         """Updates the entry in the cache
 
         Args:
             sequence
-            key
-            value (dict): The value to update the cache with.
-            full (bool): Whether the given value is the full dict, or just a
-                partial subset there of. If not full then any existing entries
-                for the key will be updated.
-            known_absent (set): Set of keys that we know don't exist in the full
-                dict.
+            key (K)
+            value (dict[X,Y]): The value to update the cache with.
+            fetched_keys (None|set[X]): All of the dictionary keys which were
+                fetched from the database.
+
+                If None, this is the complete value for key K. Otherwise, it
+                is used to infer a list of keys which we know don't exist in
+                the full dict.
         """
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
-            if known_absent is None:
-                known_absent = set()
-            if full:
-                self._insert(key, value, known_absent)
+            if fetched_keys is None:
+                self._insert(key, value, set())
             else:
-                self._update_or_insert(key, value, known_absent)
+                self._update_or_insert(key, value, fetched_keys)
 
     def _update_or_insert(self, key, value, known_absent):
         # We pop and reinsert as we need to tell the cache the size may have
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index ff04c91955..ce85b2ae11 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.caches import register_cache
-
-from collections import OrderedDict
 import logging
+from collections import OrderedDict
 
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
@@ -64,7 +64,10 @@ class ExpiringCache(object):
             return
 
         def f():
-            self._prune_cache()
+            return run_as_background_process(
+                "prune_cache_%s" % self._cache_name,
+                self._prune_cache,
+            )
 
         self._clock.looping_call(f, self._expiry_ms / 2)
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 1c5a982094..b684f24e7b 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -14,8 +14,8 @@
 # limitations under the License.
 
 
-from functools import wraps
 import threading
+from functools import wraps
 
 from synapse.util.caches.treecache import TreeCache
 
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index a8491b42d5..afb03b2e1b 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -16,7 +16,7 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches import register_cache
 from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index d03678b8c8..8318db8d2c 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.async import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred
 
 
 class SnapshotCache(object):
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 817118e30f..f2bde74dc5 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,12 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util import caches
-
+import logging
 
 from sortedcontainers import SortedDict
-import logging
 
+from synapse.util import caches
 
 logger = logging.getLogger(__name__)
 
@@ -75,13 +74,13 @@ class StreamChangeCache(object):
         assert type(stream_pos) is int
 
         if stream_pos >= self._earliest_known_stream_pos:
-            not_known_entities = set(entities) - set(self._entity_to_key)
+            changed_entities = {
+                self._cache[k] for k in self._cache.islice(
+                    start=self._cache.bisect_right(stream_pos),
+                )
+            }
 
-            result = (
-                set(self._cache.values()[self._cache.bisect_right(stream_pos) :])
-                .intersection(entities)
-                .union(not_known_entities)
-            )
+            result = changed_entities.intersection(entities)
 
             self.metrics.inc_hits()
         else:
@@ -113,7 +112,8 @@ class StreamChangeCache(object):
         assert type(stream_pos) is int
 
         if stream_pos >= self._earliest_known_stream_pos:
-            return self._cache.values()[self._cache.bisect_right(stream_pos) :]
+            return [self._cache[k] for k in self._cache.islice(
+                start=self._cache.bisect_right(stream_pos))]
         else:
             return None