summary refs log tree commit diff
path: root/synapse/util/caches
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/caches')
-rw-r--r--synapse/util/caches/__init__.py33
-rw-r--r--synapse/util/caches/descriptors.py412
-rw-r--r--synapse/util/caches/dictionary_cache.py61
-rw-r--r--synapse/util/caches/expiringcache.py16
-rw-r--r--synapse/util/caches/lrucache.py43
-rw-r--r--synapse/util/caches/response_cache.py106
-rw-r--r--synapse/util/caches/stream_change_cache.py23
7 files changed, 492 insertions, 202 deletions
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8a7774a88e..4adae96681 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -14,12 +14,9 @@
 # limitations under the License.
 
 import synapse.metrics
-from lrucache import LruCache
 import os
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-DEBUG_CACHES = False
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
 
 metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
 
@@ -40,10 +37,6 @@ def register_cache(name, cache):
     )
 
 
-_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
-_stirng_cache_metrics = register_cache("string_cache", _string_cache)
-
-
 KNOWN_KEYS = {
     key: key for key in
     (
@@ -67,14 +60,16 @@ KNOWN_KEYS = {
 
 
 def intern_string(string):
-    """Takes a (potentially) unicode string and interns using custom cache
+    """Takes a (potentially) unicode string and interns it if it's ascii
     """
-    new_str = _string_cache.setdefault(string, string)
-    if new_str is string:
-        _stirng_cache_metrics.inc_hits()
-    else:
-        _stirng_cache_metrics.inc_misses()
-    return new_str
+    if string is None:
+        return None
+
+    try:
+        string = string.encode("ascii")
+        return intern(string)
+    except UnicodeEncodeError:
+        return string
 
 
 def intern_dict(dictionary):
@@ -87,13 +82,9 @@ def intern_dict(dictionary):
 
 
 def _intern_known_values(key, value):
-    intern_str_keys = ("event_id", "room_id")
-    intern_unicode_keys = ("sender", "user_id", "type", "state_key")
-
-    if key in intern_str_keys:
-        return intern(value.encode('ascii'))
+    intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",)
 
-    if key in intern_unicode_keys:
+    if key in intern_keys:
         return intern_string(value)
 
     return value
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 998de70d29..68285a7594 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,19 +16,17 @@
 import logging
 
 from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
+from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)
+from synapse.util.stringutils import to_ascii
 
-from . import DEBUG_CACHES, register_cache
+from . import register_cache
 
 from twisted.internet import defer
 from collections import namedtuple
 
-import os
 import functools
 import inspect
 import threading
@@ -39,17 +38,13 @@ logger = logging.getLogger(__name__)
 _CacheSentinel = object()
 
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
 class CacheEntry(object):
     __slots__ = [
-        "deferred", "sequence", "callbacks", "invalidated"
+        "deferred", "callbacks", "invalidated"
     ]
 
-    def __init__(self, deferred, sequence, callbacks):
+    def __init__(self, deferred, callbacks):
         self.deferred = deferred
-        self.sequence = sequence
         self.callbacks = set(callbacks)
         self.invalidated = False
 
@@ -67,7 +62,6 @@ class Cache(object):
         "max_entries",
         "name",
         "keylen",
-        "sequence",
         "thread",
         "metrics",
         "_pending_deferred_cache",
@@ -79,15 +73,18 @@ class Cache(object):
 
         self.cache = LruCache(
             max_size=max_entries, keylen=keylen, cache_type=cache_type,
-            size_callback=(lambda d: len(d.result)) if iterable else None,
+            size_callback=(lambda d: len(d)) if iterable else None,
+            evicted_callback=self._on_evicted,
         )
 
         self.name = name
         self.keylen = keylen
-        self.sequence = 0
         self.thread = None
         self.metrics = register_cache(name, self.cache)
 
+    def _on_evicted(self, evicted_count):
+        self.metrics.inc_evictions(evicted_count)
+
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
@@ -98,21 +95,34 @@ class Cache(object):
                     "Cache objects can only be accessed from the main thread"
                 )
 
-    def get(self, key, default=_CacheSentinel, callback=None):
+    def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
+        """Looks the key up in the caches.
+
+        Args:
+            key(tuple)
+            default: What is returned if key is not in the caches. If not
+                specified then function throws KeyError instead
+            callback(fn): Gets called when the entry in the cache is invalidated
+            update_metrics (bool): whether to update the cache hit rate metrics
+
+        Returns:
+            Either a Deferred or the raw result
+        """
         callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _CacheSentinel)
         if val is not _CacheSentinel:
-            if val.sequence == self.sequence:
-                val.callbacks.update(callbacks)
+            val.callbacks.update(callbacks)
+            if update_metrics:
                 self.metrics.inc_hits()
-                return val.deferred
+            return val.deferred
 
         val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
         if val is not _CacheSentinel:
             self.metrics.inc_hits()
             return val
 
-        self.metrics.inc_misses()
+        if update_metrics:
+            self.metrics.inc_misses()
 
         if default is _CacheSentinel:
             raise KeyError()
@@ -124,12 +134,9 @@ class Cache(object):
         self.check_thread()
         entry = CacheEntry(
             deferred=value,
-            sequence=self.sequence,
             callbacks=callbacks,
         )
 
-        entry.callbacks.update(callbacks)
-
         existing_entry = self._pending_deferred_cache.pop(key, None)
         if existing_entry:
             existing_entry.invalidate()
@@ -137,13 +144,25 @@ class Cache(object):
         self._pending_deferred_cache[key] = entry
 
         def shuffle(result):
-            if self.sequence == entry.sequence:
-                existing_entry = self._pending_deferred_cache.pop(key, None)
-                if existing_entry is entry:
-                    self.cache.set(key, entry.deferred, entry.callbacks)
-                else:
-                    entry.invalidate()
+            existing_entry = self._pending_deferred_cache.pop(key, None)
+            if existing_entry is entry:
+                self.cache.set(key, result, entry.callbacks)
             else:
+                # oops, the _pending_deferred_cache has been updated since
+                # we started our query, so we are out of date.
+                #
+                # Better put back whatever we took out. (We do it this way
+                # round, rather than peeking into the _pending_deferred_cache
+                # and then removing on a match, to make the common case faster)
+                if existing_entry is not None:
+                    self._pending_deferred_cache[key] = existing_entry
+
+                # we're not going to put this entry into the cache, so need
+                # to make sure that the invalidation callbacks are called.
+                # That was probably done when _pending_deferred_cache was
+                # updated, but it's possible that `set` was called without
+                # `invalidate` being previously called, in which case it may
+                # not have been. Either way, let's double-check now.
                 entry.invalidate()
             return result
 
@@ -155,29 +174,29 @@ class Cache(object):
 
     def invalidate(self, key):
         self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError(
-                "The cache key must be a tuple not %r" % (type(key),)
-            )
+        self.cache.pop(key, None)
 
-        # Increment the sequence number so that any SELECT statements that
-        # raced with the INSERT don't update the cache (SYN-369)
-        self.sequence += 1
+        # if we have a pending lookup for this key, remove it from the
+        # _pending_deferred_cache, which will (a) stop it being returned
+        # for future queries and (b) stop it being persisted as a proper entry
+        # in self.cache.
         entry = self._pending_deferred_cache.pop(key, None)
+
+        # run the invalidation callbacks now, rather than waiting for the
+        # deferred to resolve.
         if entry:
             entry.invalidate()
 
-        self.cache.pop(key, None)
-
     def invalidate_many(self, key):
         self.check_thread()
         if not isinstance(key, tuple):
             raise TypeError(
                 "The cache key must be a tuple not %r" % (type(key),)
             )
-        self.sequence += 1
         self.cache.del_multi(key)
 
+        # if we have a pending lookup for this key, remove it from the
+        # _pending_deferred_cache, as above
         entry_dict = self._pending_deferred_cache.pop(key, None)
         if entry_dict is not None:
             for entry in iterate_tree_cache_entry(entry_dict):
@@ -185,11 +204,73 @@ class Cache(object):
 
     def invalidate_all(self):
         self.check_thread()
-        self.sequence += 1
         self.cache.clear()
+        for entry in self._pending_deferred_cache.itervalues():
+            entry.invalidate()
+        self._pending_deferred_cache.clear()
+
+
+class _CacheDescriptorBase(object):
+    def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+        self.orig = orig
+
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
 
+        arg_spec = inspect.getargspec(orig)
+        all_args = arg_spec.args
 
-class CacheDescriptor(object):
+        if "cache_context" in all_args:
+            if not cache_context:
+                raise ValueError(
+                    "Cannot have a 'cache_context' arg without setting"
+                    " cache_context=True"
+                )
+        elif cache_context:
+            raise ValueError(
+                "Cannot have cache_context=True without having an arg"
+                " named `cache_context`"
+            )
+
+        if num_args is None:
+            num_args = len(all_args) - 1
+            if cache_context:
+                num_args -= 1
+
+        if len(all_args) < num_args + 1:
+            raise Exception(
+                "Not enough explicit positional arguments to key off for %r: "
+                "got %i args, but wanted %i. (@cached cannot key off *args or "
+                "**kwargs)"
+                % (orig.__name__, len(all_args), num_args)
+            )
+
+        self.num_args = num_args
+
+        # list of the names of the args used as the cache key
+        self.arg_names = all_args[1:num_args + 1]
+
+        # self.arg_defaults is a map of arg name to its default value for each
+        # argument that has a default value
+        if arg_spec.defaults:
+            self.arg_defaults = dict(zip(
+                all_args[-len(arg_spec.defaults):],
+                arg_spec.defaults
+            ))
+        else:
+            self.arg_defaults = {}
+
+        if "cache_context" in self.arg_names:
+            raise Exception(
+                "cache_context arg cannot be included among the cache keys"
+            )
+
+        self.add_cache_context = cache_context
+
+
+class CacheDescriptor(_CacheDescriptorBase):
     """ A method decorator that applies a memoizing cache around the function.
 
     This caches deferreds, rather than the results themselves. Deferreds that
@@ -217,52 +298,24 @@ class CacheDescriptor(object):
             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
             defer.returnValue(r1 + r2)
 
+    Args:
+        num_args (int): number of positional arguments (excluding ``self`` and
+            ``cache_context``) to use as cache keys. Defaults to all named
+            args of the function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+    def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
                  inlineCallbacks=False, cache_context=False, iterable=False):
-        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
-        self.orig = orig
+        super(CacheDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
+            cache_context=cache_context)
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
+        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.max_entries = max_entries
-        self.num_args = num_args
         self.tree = tree
-
         self.iterable = iterable
 
-        all_args = inspect.getargspec(orig)
-        self.arg_names = all_args.args[1:num_args + 1]
-
-        if "cache_context" in all_args.args:
-            if not cache_context:
-                raise ValueError(
-                    "Cannot have a 'cache_context' arg without setting"
-                    " cache_context=True"
-                )
-            try:
-                self.arg_names.remove("cache_context")
-            except ValueError:
-                pass
-        elif cache_context:
-            raise ValueError(
-                "Cannot have cache_context=True without having an arg"
-                " named `cache_context`"
-            )
-
-        self.add_cache_context = cache_context
-
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwargs)"
-                % (orig.__name__,)
-            )
-
     def __get__(self, obj, objtype=None):
         cache = Cache(
             name=self.orig.__name__,
@@ -272,18 +325,47 @@ class CacheDescriptor(object):
             iterable=self.iterable,
         )
 
+        def get_cache_key_gen(args, kwargs):
+            """Given some args/kwargs return a generator that resolves into
+            the cache_key.
+
+            We loop through each arg name, looking up if its in the `kwargs`,
+            otherwise using the next argument in `args`. If there are no more
+            args then we try looking the arg name up in the defaults
+            """
+            pos = 0
+            for nm in self.arg_names:
+                if nm in kwargs:
+                    yield kwargs[nm]
+                elif pos < len(args):
+                    yield args[pos]
+                    pos += 1
+                else:
+                    yield self.arg_defaults[nm]
+
+        # By default our cache key is a tuple, but if there is only one item
+        # then don't bother wrapping in a tuple.  This is to save memory.
+        if self.num_args == 1:
+            nm = self.arg_names[0]
+
+            def get_cache_key(args, kwargs):
+                if nm in kwargs:
+                    return kwargs[nm]
+                elif len(args):
+                    return args[0]
+                else:
+                    return self.arg_defaults[nm]
+        else:
+            def get_cache_key(args, kwargs):
+                return tuple(get_cache_key_gen(args, kwargs))
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add temp cache_context so inspect.getcallargs doesn't explode
-            if self.add_cache_context:
-                kwargs["cache_context"] = None
-
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            cache_key = get_cache_key(args, kwargs)
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -293,26 +375,14 @@ class CacheDescriptor(object):
             try:
                 cached_result_d = cache.get(cache_key, callback=invalidate_callback)
 
-                observer = cached_result_d.observe()
-                if DEBUG_CACHES:
-                    @defer.inlineCallbacks
-                    def check_result(cached_result):
-                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
-                        if actual_result != cached_result:
-                            logger.error(
-                                "Stale cache entry %s%r: cached: %r, actual %r",
-                                self.orig.__name__, cache_key,
-                                cached_result, actual_result,
-                            )
-                            raise ValueError("Stale cache entry")
-                        defer.returnValue(cached_result)
-                    observer.addCallback(check_result)
-
-                return preserve_context_over_deferred(observer)
+                if isinstance(cached_result_d, ObservableDeferred):
+                    observer = cached_result_d.observe()
+                else:
+                    observer = cached_result_d
+
             except KeyError:
                 ret = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     obj, *args, **kwargs
                 )
 
@@ -322,64 +392,72 @@ class CacheDescriptor(object):
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.set(cache_key, ret, callback=invalidate_callback)
+                # If our cache_key is a string, try to convert to ascii to save
+                # a bit of space in large caches
+                if isinstance(cache_key, basestring):
+                    cache_key = to_ascii(cache_key)
 
-                return preserve_context_over_deferred(ret.observe())
+                result_d = ObservableDeferred(ret, consumeErrors=True)
+                cache.set(cache_key, result_d, callback=invalidate_callback)
+                observer = result_d.observe()
+
+            if isinstance(observer, defer.Deferred):
+                return logcontext.make_deferred_yieldable(observer)
+            else:
+                return observer
+
+        if self.num_args == 1:
+            wrapped.invalidate = lambda key: cache.invalidate(key[0])
+            wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
+        else:
+            wrapped.invalidate = cache.invalidate
+            wrapped.invalidate_all = cache.invalidate_all
+            wrapped.invalidate_many = cache.invalidate_many
+            wrapped.prefill = cache.prefill
 
-        wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
-        wrapped.invalidate_many = cache.invalidate_many
-        wrapped.prefill = cache.prefill
         wrapped.cache = cache
+        wrapped.num_args = self.num_args
 
         obj.__dict__[self.orig.__name__] = wrapped
 
         return wrapped
 
 
-class CacheListDescriptor(object):
+class CacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped fucntion.
+    the list of missing keys to the wrapped function.
+
+    Once wrapped, the function returns either a Deferred which resolves to
+    the list of results, or (if all results were cached), just the list of
+    results.
     """
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=1,
+    def __init__(self, orig, cached_method_name, list_name, num_args=None,
                  inlineCallbacks=False):
         """
         Args:
             orig (function)
-            method_name (str); The name of the chached method.
+            cached_method_name (str): The name of the chached method.
             list_name (str): Name of the argument which is the bulk lookup list
-            num_args (int)
+            num_args (int): number of positional arguments (excluding ``self``,
+                but including list_name) to use as cache keys. Defaults to all
+                named args of the function.
             inlineCallbacks (bool): Whether orig is a generator that should
                 be wrapped by defer.inlineCallbacks
         """
-        self.orig = orig
+        super(CacheListDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
-        self.num_args = num_args
         self.list_name = list_name
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
-
         self.cached_method_name = cached_method_name
 
         self.sentinel = object()
 
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
-                % (orig.__name__,)
-            )
-
         if self.list_name not in self.arg_names:
             raise Exception(
                 "Couldn't see arguments %r for %r."
@@ -387,8 +465,9 @@ class CacheListDescriptor(object):
             )
 
     def __get__(self, obj, objtype=None):
-
-        cache = getattr(obj, self.cached_method_name).cache
+        cached_method = getattr(obj, self.cached_method_name)
+        cache = cached_method.cache
+        num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
@@ -405,13 +484,26 @@ class CacheListDescriptor(object):
             results = {}
             cached_defers = {}
             missing = []
-            for arg in list_args:
+
+            # If the cache takes a single arg then that is used as the key,
+            # otherwise a tuple is used.
+            if num_args == 1:
+                def cache_get(arg):
+                    return cache.get(arg, callback=invalidate_callback)
+            else:
                 key = list(keyargs)
-                key[self.list_pos] = arg
 
+                def cache_get(arg):
+                    key[self.list_pos] = arg
+                    return cache.get(tuple(key), callback=invalidate_callback)
+
+            for arg in list_args:
                 try:
-                    res = cache.get(tuple(key), callback=invalidate_callback)
-                    if not res.has_succeeded():
+                    res = cache_get(arg)
+
+                    if not isinstance(res, ObservableDeferred):
+                        results[arg] = res
+                    elif not res.has_succeeded():
                         res = res.observe()
                         res.addCallback(lambda r, arg: (arg, r), arg)
                         cached_defers[arg] = res
@@ -425,8 +517,7 @@ class CacheListDescriptor(object):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
                 )
 
@@ -435,23 +526,33 @@ class CacheListDescriptor(object):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    with PreserveLoggingContext():
-                        observer = ret_d.observe()
+                    observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
 
-                    key = list(keyargs)
-                    key[self.list_pos] = arg
-                    cache.set(
-                        tuple(key), observer,
-                        callback=invalidate_callback
-                    )
+                    if num_args == 1:
+                        cache.set(
+                            arg, observer,
+                            callback=invalidate_callback
+                        )
 
-                    def invalidate(f, key):
-                        cache.invalidate(key)
-                        return f
-                    observer.addErrback(invalidate, tuple(key))
+                        def invalidate(f, key):
+                            cache.invalidate(key)
+                            return f
+                        observer.addErrback(invalidate, arg)
+                    else:
+                        key = list(keyargs)
+                        key[self.list_pos] = arg
+                        cache.set(
+                            tuple(key), observer,
+                            callback=invalidate_callback
+                        )
+
+                        def invalidate(f, key):
+                            cache.invalidate(key)
+                            return f
+                        observer.addErrback(invalidate, tuple(key))
 
                     res = observer.observe()
                     res.addCallback(lambda r, arg: (arg, r), arg)
@@ -463,7 +564,7 @@ class CacheListDescriptor(object):
                     results.update(res)
                     return results
 
-                return preserve_context_over_deferred(defer.gatherResults(
+                return logcontext.make_deferred_yieldable(defer.gatherResults(
                     cached_defers.values(),
                     consumeErrors=True,
                 ).addCallback(update_results_dict).addErrback(
@@ -487,7 +588,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
            iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
@@ -499,8 +600,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
-                          iterable=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
+                          cache_context=False, iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -512,7 +613,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
     )
 
 
-def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
     """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. A single argument
@@ -525,7 +626,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
         cache (Cache): The underlying cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
-        num_args (int): Number of arguments to use as the key in the cache.
+        num_args (int): Number of arguments to use as the key in the cache
+            (including list_name). Defaults to all named parameters.
         inlineCallbacks (bool): Should the function be wrapped in an
             `defer.inlineCallbacks`?
 
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index cb6933c61c..1709e8b429 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -23,7 +23,17 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
+class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
+    """Returned when getting an entry from the cache
+
+    Attributes:
+        full (bool): Whether the cache has the full or dict or just some keys.
+            If not full then not all requested keys will necessarily be present
+            in `value`
+        known_absent (set): Keys that were looked up in the dict and were not
+            there.
+        value (dict): The full or partial dict value
+    """
     def __len__(self):
         return len(self.value)
 
@@ -58,21 +68,31 @@ class DictionaryCache(object):
                 )
 
     def get(self, key, dict_keys=None):
+        """Fetch an entry out of the cache
+
+        Args:
+            key
+            dict_key(list): If given a set of keys then return only those keys
+                that exist in the cache.
+
+        Returns:
+            DictionaryEntry
+        """
         entry = self.cache.get(key, self.sentinel)
         if entry is not self.sentinel:
             self.metrics.inc_hits()
 
             if dict_keys is None:
-                return DictionaryEntry(entry.full, dict(entry.value))
+                return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
             else:
-                return DictionaryEntry(entry.full, {
+                return DictionaryEntry(entry.full, entry.known_absent, {
                     k: entry.value[k]
                     for k in dict_keys
                     if k in entry.value
                 })
 
         self.metrics.inc_misses()
-        return DictionaryEntry(False, {})
+        return DictionaryEntry(False, set(), {})
 
     def invalidate(self, key):
         self.check_thread()
@@ -87,19 +107,38 @@ class DictionaryCache(object):
         self.sequence += 1
         self.cache.clear()
 
-    def update(self, sequence, key, value, full=False):
+    def update(self, sequence, key, value, full=False, known_absent=None):
+        """Updates the entry in the cache
+
+        Args:
+            sequence
+            key
+            value (dict): The value to update the cache with.
+            full (bool): Whether the given value is the full dict, or just a
+                partial subset there of. If not full then any existing entries
+                for the key will be updated.
+            known_absent (set): Set of keys that we know don't exist in the full
+                dict.
+        """
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
+            if known_absent is None:
+                known_absent = set()
             if full:
-                self._insert(key, value)
+                self._insert(key, value, known_absent)
             else:
-                self._update_or_insert(key, value)
+                self._update_or_insert(key, value, known_absent)
+
+    def _update_or_insert(self, key, value, known_absent):
+        # We pop and reinsert as we need to tell the cache the size may have
+        # changed
 
-    def _update_or_insert(self, key, value):
-        entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
+        entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
         entry.value.update(value)
+        entry.known_absent.update(known_absent)
+        self.cache[key] = entry
 
-    def _insert(self, key, value):
-        self.cache[key] = DictionaryEntry(True, value)
+    def _insert(self, key, value, known_absent):
+        self.cache[key] = DictionaryEntry(True, known_absent, value)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2987c38a2d..0aa103eecb 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -79,7 +79,11 @@ class ExpiringCache(object):
         while self._max_len and len(self) > self._max_len:
             _key, value = self._cache.popitem(last=False)
             if self.iterable:
-                self._size_estimate -= len(value.value)
+                removed_len = len(value.value)
+                self.metrics.inc_evictions(removed_len)
+                self._size_estimate -= removed_len
+            else:
+                self.metrics.inc_evictions()
 
     def __getitem__(self, key):
         try:
@@ -94,12 +98,22 @@ class ExpiringCache(object):
 
         return entry.value
 
+    def __contains__(self, key):
+        return key in self._cache
+
     def get(self, key, default=None):
         try:
             return self[key]
         except KeyError:
             return default
 
+    def setdefault(self, key, value):
+        try:
+            return self[key]
+        except KeyError:
+            self[key] = value
+            return value
+
     def _prune_cache(self):
         if not self._expiry_ms:
             # zero expiry time means don't expire. This should never get called
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cf5fbb679c..1c5a982094 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,24 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
+    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
+                 evicted_callback=None):
+        """
+        Args:
+            max_size (int):
+
+            keylen (int):
+
+            cache_type (type):
+                type of underlying cache to be used. Typically one of dict
+                or TreeCache.
+
+            size_callback (func(V) -> int | None):
+
+            evicted_callback (func(int)|None):
+                if not None, called on eviction with the size of the evicted
+                entry
+        """
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         list_root = _Node(None, None, None, None)
@@ -61,8 +78,10 @@ class LruCache(object):
         def evict():
             while cache_len() > max_size:
                 todelete = list_root.prev_node
-                delete_node(todelete)
+                evicted_len = delete_node(todelete)
                 cache.pop(todelete.key, None)
+                if evicted_callback:
+                    evicted_callback(evicted_len)
 
         def synchronized(f):
             @wraps(f)
@@ -111,12 +130,15 @@ class LruCache(object):
             prev_node.next_node = next_node
             next_node.prev_node = prev_node
 
+            deleted_len = 1
             if size_callback:
-                cached_cache_len[0] -= size_callback(node.value)
+                deleted_len = size_callback(node.value)
+                cached_cache_len[0] -= deleted_len
 
             for cb in node.callbacks:
                 cb()
             node.callbacks.clear()
+            return deleted_len
 
         @synchronized
         def cache_get(key, default=None, callbacks=[]):
@@ -132,14 +154,21 @@ class LruCache(object):
         def cache_set(key, value, callbacks=[]):
             node = cache.get(key, None)
             if node is not None:
-                if value != node.value:
+                # We sometimes store large objects, e.g. dicts, which cause
+                # the inequality check to take a long time. So let's only do
+                # the check if we have some callbacks to call.
+                if node.callbacks and value != node.value:
                     for cb in node.callbacks:
                         cb()
                     node.callbacks.clear()
 
-                    if size_callback:
-                        cached_cache_len[0] -= size_callback(node.value)
-                        cached_cache_len[0] += size_callback(value)
+                # We don't bother to protect this by value != node.value as
+                # generally size_callback will be cheap compared with equality
+                # checks. (For example, taking the size of two dicts is quicker
+                # than comparing them for equality.)
+                if size_callback:
+                    cached_cache_len[0] -= size_callback(node.value)
+                    cached_cache_len[0] += size_callback(value)
 
                 node.callbacks.update(callbacks)
 
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 00af539880..7f79333e96 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,8 +12,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
+from twisted.internet import defer
 
 from synapse.util.async import ObservableDeferred
+from synapse.util.caches import metrics as cache_metrics
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
+logger = logging.getLogger(__name__)
 
 
 class ResponseCache(object):
@@ -24,20 +31,68 @@ class ResponseCache(object):
     used rather than trying to compute a new response.
     """
 
-    def __init__(self, hs, timeout_ms=0):
+    def __init__(self, hs, name, timeout_ms=0):
         self.pending_result_cache = {}  # Requests that haven't finished yet.
 
         self.clock = hs.get_clock()
         self.timeout_sec = timeout_ms / 1000.
 
+        self._name = name
+        self._metrics = cache_metrics.register_cache(
+            "response_cache",
+            size_callback=lambda: self.size(),
+            cache_name=name,
+        )
+
+    def size(self):
+        return len(self.pending_result_cache)
+
     def get(self, key):
+        """Look up the given key.
+
+        Can return either a new Deferred (which also doesn't follow the synapse
+        logcontext rules), or, if the request has completed, the actual
+        result. You will probably want to make_deferred_yieldable the result.
+
+        If there is no entry for the key, returns None. It is worth noting that
+        this means there is no way to distinguish a completed result of None
+        from an absent cache entry.
+
+        Args:
+            key (hashable):
+
+        Returns:
+            twisted.internet.defer.Deferred|None|E: None if there is no entry
+            for this key; otherwise either a deferred result or the result
+            itself.
+        """
         result = self.pending_result_cache.get(key)
         if result is not None:
+            self._metrics.inc_hits()
             return result.observe()
         else:
+            self._metrics.inc_misses()
             return None
 
     def set(self, key, deferred):
+        """Set the entry for the given key to the given deferred.
+
+        *deferred* should run its callbacks in the sentinel logcontext (ie,
+        you should wrap normal synapse deferreds with
+        logcontext.run_in_background).
+
+        Can return either a new Deferred (which also doesn't follow the synapse
+        logcontext rules), or, if *deferred* was already complete, the actual
+        result. You will probably want to make_deferred_yieldable the result.
+
+        Args:
+            key (hashable):
+            deferred (twisted.internet.defer.Deferred[T):
+
+        Returns:
+            twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
+                result.
+        """
         result = ObservableDeferred(deferred, consumeErrors=True)
         self.pending_result_cache[key] = result
 
@@ -53,3 +108,52 @@ class ResponseCache(object):
 
         result.addBoth(remove)
         return result.observe()
+
+    def wrap(self, key, callback, *args, **kwargs):
+        """Wrap together a *get* and *set* call, taking care of logcontexts
+
+        First looks up the key in the cache, and if it is present makes it
+        follow the synapse logcontext rules and returns it.
+
+        Otherwise, makes a call to *callback(*args, **kwargs)*, which should
+        follow the synapse logcontext rules, and adds the result to the cache.
+
+        Example usage:
+
+            @defer.inlineCallbacks
+            def handle_request(request):
+                # etc
+                defer.returnValue(result)
+
+            result = yield response_cache.wrap(
+                key,
+                handle_request,
+                request,
+            )
+
+        Args:
+            key (hashable): key to get/set in the cache
+
+            callback (callable): function to call if the key is not found in
+                the cache
+
+            *args: positional parameters to pass to the callback, if it is used
+
+            **kwargs: named paramters to pass to the callback, if it is used
+
+        Returns:
+            twisted.internet.defer.Deferred: yieldable result
+        """
+        result = self.get(key)
+        if not result:
+            logger.info("[%s]: no cached result for [%s], calculating new one",
+                        self._name, key)
+            d = run_in_background(callback, *args, **kwargs)
+            result = self.set(key, d)
+        elif not isinstance(result, defer.Deferred) or result.called:
+            logger.info("[%s]: using completed cached result for [%s]",
+                        self._name, key)
+        else:
+            logger.info("[%s]: using incomplete cached result for [%s]",
+                        self._name, key)
+        return make_deferred_yieldable(result)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index b72bb0ff02..941d873ab8 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,20 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.caches import register_cache
+from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
 
 
 from blist import sorteddict
 import logging
-import os
 
 
 logger = logging.getLogger(__name__)
 
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
 class StreamChangeCache(object):
     """Keeps track of the stream positions of the latest change in a set of entities.
 
@@ -50,7 +46,7 @@ class StreamChangeCache(object):
     def has_entity_changed(self, entity, stream_pos):
         """Returns True if the entity may have been updated since stream_pos
         """
-        assert type(stream_pos) is int
+        assert type(stream_pos) is int or type(stream_pos) is long
 
         if stream_pos < self._earliest_known_stream_pos:
             self.metrics.inc_misses()
@@ -89,6 +85,21 @@ class StreamChangeCache(object):
 
         return result
 
+    def has_any_entity_changed(self, stream_pos):
+        """Returns if any entity has changed
+        """
+        assert type(stream_pos) is int
+
+        if stream_pos >= self._earliest_known_stream_pos:
+            self.metrics.inc_hits()
+            keys = self._cache.keys()
+            i = keys.bisect_right(stream_pos)
+
+            return i < len(keys)
+        else:
+            self.metrics.inc_misses()
+            return True
+
     def get_all_entities_changed(self, stream_pos):
         """Returns all entites that have had new things since the given
         position. If the position is too old it will return None.