summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py10
-rw-r--r--synapse/util/async.py7
-rw-r--r--synapse/util/caches/descriptors.py205
-rw-r--r--synapse/util/caches/stream_change_cache.py2
-rw-r--r--synapse/util/logcontext.py77
-rw-r--r--synapse/util/retryutils.py34
6 files changed, 227 insertions, 108 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 30fc480108..98a5a26ac5 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 class DeferredTimedOutError(SynapseError):
     def __init__(self):
-        super(SynapseError).__init__(504, "Timed out")
+        super(SynapseError, self).__init__(504, "Timed out")
 
 
 def unwrapFirstError(failure):
@@ -93,8 +93,10 @@ class Clock(object):
         ret_deferred = defer.Deferred()
 
         def timed_out_fn():
+            e = DeferredTimedOutError()
+
             try:
-                ret_deferred.errback(DeferredTimedOutError())
+                ret_deferred.errback(e)
             except:
                 pass
 
@@ -114,7 +116,7 @@ class Clock(object):
 
         ret_deferred.addBoth(cancel)
 
-        def sucess(res):
+        def success(res):
             try:
                 ret_deferred.callback(res)
             except:
@@ -128,7 +130,7 @@ class Clock(object):
             except:
                 pass
 
-        given_deferred.addCallbacks(callback=sucess, errback=err)
+        given_deferred.addCallbacks(callback=success, errback=err)
 
         timer = self.call_later(time_out, timed_out_fn)
 
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 35380bf8ed..1453faf0ef 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -89,6 +89,11 @@ class ObservableDeferred(object):
         deferred.addCallbacks(callback, errback)
 
     def observe(self):
+        """Observe the underlying deferred.
+
+        Can return either a deferred if the underlying deferred is still pending
+        (or has failed), or the actual value. Callers may need to use maybeDeferred.
+        """
         if not self._result:
             d = defer.Deferred()
 
@@ -101,7 +106,7 @@ class ObservableDeferred(object):
             return d
         else:
             success, res = self._result
-            return defer.succeed(res) if success else defer.fail(res)
+            return res if success else defer.fail(res)
 
     def observers(self):
         return self._observers
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 998de70d29..9d0d0be1f9 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -15,12 +15,9 @@
 import logging
 
 from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)
 
 from . import DEBUG_CACHES, register_cache
 
@@ -189,7 +186,67 @@ class Cache(object):
         self.cache.clear()
 
 
-class CacheDescriptor(object):
+class _CacheDescriptorBase(object):
+    def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+        self.orig = orig
+
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
+
+        arg_spec = inspect.getargspec(orig)
+        all_args = arg_spec.args
+
+        if "cache_context" in all_args:
+            if not cache_context:
+                raise ValueError(
+                    "Cannot have a 'cache_context' arg without setting"
+                    " cache_context=True"
+                )
+        elif cache_context:
+            raise ValueError(
+                "Cannot have cache_context=True without having an arg"
+                " named `cache_context`"
+            )
+
+        if num_args is None:
+            num_args = len(all_args) - 1
+            if cache_context:
+                num_args -= 1
+
+        if len(all_args) < num_args + 1:
+            raise Exception(
+                "Not enough explicit positional arguments to key off for %r: "
+                "got %i args, but wanted %i. (@cached cannot key off *args or "
+                "**kwargs)"
+                % (orig.__name__, len(all_args), num_args)
+            )
+
+        self.num_args = num_args
+
+        # list of the names of the args used as the cache key
+        self.arg_names = all_args[1:num_args + 1]
+
+        # self.arg_defaults is a map of arg name to its default value for each
+        # argument that has a default value
+        if arg_spec.defaults:
+            self.arg_defaults = dict(zip(
+                all_args[-len(arg_spec.defaults):],
+                arg_spec.defaults
+            ))
+        else:
+            self.arg_defaults = {}
+
+        if "cache_context" in self.arg_names:
+            raise Exception(
+                "cache_context arg cannot be included among the cache keys"
+            )
+
+        self.add_cache_context = cache_context
+
+
+class CacheDescriptor(_CacheDescriptorBase):
     """ A method decorator that applies a memoizing cache around the function.
 
     This caches deferreds, rather than the results themselves. Deferreds that
@@ -217,52 +274,24 @@ class CacheDescriptor(object):
             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
             defer.returnValue(r1 + r2)
 
+    Args:
+        num_args (int): number of positional arguments (excluding ``self`` and
+            ``cache_context``) to use as cache keys. Defaults to all named
+            args of the function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+    def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
                  inlineCallbacks=False, cache_context=False, iterable=False):
-        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
-        self.orig = orig
+        super(CacheDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
+            cache_context=cache_context)
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
+        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.max_entries = max_entries
-        self.num_args = num_args
         self.tree = tree
-
         self.iterable = iterable
 
-        all_args = inspect.getargspec(orig)
-        self.arg_names = all_args.args[1:num_args + 1]
-
-        if "cache_context" in all_args.args:
-            if not cache_context:
-                raise ValueError(
-                    "Cannot have a 'cache_context' arg without setting"
-                    " cache_context=True"
-                )
-            try:
-                self.arg_names.remove("cache_context")
-            except ValueError:
-                pass
-        elif cache_context:
-            raise ValueError(
-                "Cannot have cache_context=True without having an arg"
-                " named `cache_context`"
-            )
-
-        self.add_cache_context = cache_context
-
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwargs)"
-                % (orig.__name__,)
-            )
-
     def __get__(self, obj, objtype=None):
         cache = Cache(
             name=self.orig.__name__,
@@ -272,18 +301,31 @@ class CacheDescriptor(object):
             iterable=self.iterable,
         )
 
+        def get_cache_key(args, kwargs):
+            """Given some args/kwargs return a generator that resolves into
+            the cache_key.
+
+            We loop through each arg name, looking up if its in the `kwargs`,
+            otherwise using the next argument in `args`. If there are no more
+            args then we try looking the arg name up in the defaults
+            """
+            pos = 0
+            for nm in self.arg_names:
+                if nm in kwargs:
+                    yield kwargs[nm]
+                elif pos < len(args):
+                    yield args[pos]
+                    pos += 1
+                else:
+                    yield self.arg_defaults[nm]
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add temp cache_context so inspect.getcallargs doesn't explode
-            if self.add_cache_context:
-                kwargs["cache_context"] = None
-
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            cache_key = tuple(get_cache_key(args, kwargs))
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -308,11 +350,9 @@ class CacheDescriptor(object):
                         defer.returnValue(cached_result)
                     observer.addCallback(check_result)
 
-                return preserve_context_over_deferred(observer)
             except KeyError:
                 ret = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     obj, *args, **kwargs
                 )
 
@@ -322,10 +362,14 @@ class CacheDescriptor(object):
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.set(cache_key, ret, callback=invalidate_callback)
+                result_d = ObservableDeferred(ret, consumeErrors=True)
+                cache.set(cache_key, result_d, callback=invalidate_callback)
+                observer = result_d.observe()
 
-                return preserve_context_over_deferred(ret.observe())
+            if isinstance(observer, defer.Deferred):
+                return logcontext.make_deferred_yieldable(observer)
+            else:
+                return observer
 
         wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
@@ -338,48 +382,40 @@ class CacheDescriptor(object):
         return wrapped
 
 
-class CacheListDescriptor(object):
+class CacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped fucntion.
+    the list of missing keys to the wrapped function.
+
+    Once wrapped, the function returns either a Deferred which resolves to
+    the list of results, or (if all results were cached), just the list of
+    results.
     """
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=1,
+    def __init__(self, orig, cached_method_name, list_name, num_args=None,
                  inlineCallbacks=False):
         """
         Args:
             orig (function)
-            method_name (str); The name of the chached method.
+            cached_method_name (str): The name of the chached method.
             list_name (str): Name of the argument which is the bulk lookup list
-            num_args (int)
+            num_args (int): number of positional arguments (excluding ``self``,
+                but including list_name) to use as cache keys. Defaults to all
+                named args of the function.
             inlineCallbacks (bool): Whether orig is a generator that should
                 be wrapped by defer.inlineCallbacks
         """
-        self.orig = orig
+        super(CacheListDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
-        self.num_args = num_args
         self.list_name = list_name
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
-
         self.cached_method_name = cached_method_name
 
         self.sentinel = object()
 
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
-                % (orig.__name__,)
-            )
-
         if self.list_name not in self.arg_names:
             raise Exception(
                 "Couldn't see arguments %r for %r."
@@ -425,8 +461,7 @@ class CacheListDescriptor(object):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
                 )
 
@@ -435,8 +470,7 @@ class CacheListDescriptor(object):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    with PreserveLoggingContext():
-                        observer = ret_d.observe()
+                    observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
@@ -463,7 +497,7 @@ class CacheListDescriptor(object):
                     results.update(res)
                     return results
 
-                return preserve_context_over_deferred(defer.gatherResults(
+                return logcontext.make_deferred_yieldable(defer.gatherResults(
                     cached_defers.values(),
                     consumeErrors=True,
                 ).addCallback(update_results_dict).addErrback(
@@ -487,7 +521,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
            iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
@@ -499,8 +533,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
-                          iterable=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
+                          cache_context=False, iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -512,7 +546,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
     )
 
 
-def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
     """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. A single argument
@@ -525,7 +559,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
         cache (Cache): The underlying cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
-        num_args (int): Number of arguments to use as the key in the cache.
+        num_args (int): Number of arguments to use as the key in the cache
+            (including list_name). Defaults to all named parameters.
         inlineCallbacks (bool): Should the function be wrapped in an
             `defer.inlineCallbacks`?
 
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index b72bb0ff02..70fe00ce0b 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -50,7 +50,7 @@ class StreamChangeCache(object):
     def has_entity_changed(self, entity, stream_pos):
         """Returns True if the entity may have been updated since stream_pos
         """
-        assert type(stream_pos) is int
+        assert type(stream_pos) is int or type(stream_pos) is long
 
         if stream_pos < self._earliest_known_stream_pos:
             self.metrics.inc_misses()
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 6c83eb213d..857afee7cb 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+""" Thread-local-alike tracking of log contexts within synapse
+
+This module provides objects and utilities for tracking contexts through
+synapse code, so that log lines can include a request identifier, and so that
+CPU and database activity can be accounted for against the request that caused
+them.
+
+See doc/log_contexts.rst for details on how this works.
+"""
+
 from twisted.internet import defer
 
 import threading
@@ -300,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs):
 def preserve_context_over_deferred(deferred, context=None):
     """Given a deferred wrap it such that any callbacks added later to it will
     be invoked with the current context.
+
+    Deprecated: this almost certainly doesn't do want you want, ie make
+    the deferred follow the synapse logcontext rules: try
+    ``make_deferred_yieldable`` instead.
     """
     if context is None:
         context = LoggingContext.current_context()
@@ -309,24 +323,65 @@ def preserve_context_over_deferred(deferred, context=None):
 
 
 def preserve_fn(f):
-    """Ensures that function is called with correct context and that context is
-    restored after return. Useful for wrapping functions that return a deferred
-    which you don't yield on.
+    """Wraps a function, to ensure that the current context is restored after
+    return from the function, and that the sentinel context is set once the
+    deferred returned by the funtion completes.
+
+    Useful for wrapping functions that return a deferred which you don't yield
+    on.
     """
+    def reset_context(result):
+        LoggingContext.set_current_context(LoggingContext.sentinel)
+        return result
+
+    # XXX: why is this here rather than inside g? surely we want to preserve
+    # the context from the time the function was called, not when it was
+    # wrapped?
     current = LoggingContext.current_context()
 
     def g(*args, **kwargs):
-        with PreserveLoggingContext(current):
-            res = f(*args, **kwargs)
-            if isinstance(res, defer.Deferred):
-                return preserve_context_over_deferred(
-                    res, context=LoggingContext.sentinel
-                )
-            else:
-                return res
+        res = f(*args, **kwargs)
+        if isinstance(res, defer.Deferred) and not res.called:
+            # The function will have reset the context before returning, so
+            # we need to restore it now.
+            LoggingContext.set_current_context(current)
+
+            # The original context will be restored when the deferred
+            # completes, but there is nothing waiting for it, so it will
+            # get leaked into the reactor or some other function which
+            # wasn't expecting it. We therefore need to reset the context
+            # here.
+            #
+            # (If this feels asymmetric, consider it this way: we are
+            # effectively forking a new thread of execution. We are
+            # probably currently within a ``with LoggingContext()`` block,
+            # which is supposed to have a single entry and exit point. But
+            # by spawning off another deferred, we are effectively
+            # adding a new exit point.)
+            res.addBoth(reset_context)
+        return res
     return g
 
 
+@defer.inlineCallbacks
+def make_deferred_yieldable(deferred):
+    """Given a deferred, make it follow the Synapse logcontext rules:
+
+    If the deferred has completed (or is not actually a Deferred), essentially
+    does nothing (just returns another completed deferred with the
+    result/failure).
+
+    If the deferred has not yet completed, resets the logcontext before
+    returning a deferred. Then, when the deferred completes, restores the
+    current logcontext before running callbacks/errbacks.
+
+    (This is more-or-less the opposite operation to preserve_fn.)
+    """
+    with PreserveLoggingContext():
+        r = yield deferred
+    defer.returnValue(r)
+
+
 # modules to ignore in `logcontext_tracer`
 _to_ignore = [
     "synapse.util.logcontext",
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 153ef001ad..4fa9d1a03c 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import synapse.util.logcontext
 from twisted.internet import defer
 
 from synapse.api.errors import CodeMessageException
@@ -35,7 +35,8 @@ class NotRetryingDestination(Exception):
 
 
 @defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, **kwargs):
+def get_retry_limiter(destination, clock, store, ignore_backoff=False,
+                      **kwargs):
     """For a given destination check if we have previously failed to
     send a request there and are waiting before retrying the destination.
     If we are not ready to retry the destination, this will raise a
@@ -43,6 +44,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
     that will mark the destination as down if an exception is thrown (excluding
     CodeMessageException with code < 500)
 
+    Args:
+        destination (str): name of homeserver
+        clock (synapse.util.clock): timing source
+        store (synapse.storage.transactions.TransactionStore): datastore
+        ignore_backoff (bool): true to ignore the historical backoff data and
+            try the request anyway. We will still update the next
+            retry_interval on success/failure.
+
     Example usage:
 
         try:
@@ -66,7 +75,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
 
         now = int(clock.time_msec())
 
-        if retry_last_ts + retry_interval > now:
+        if not ignore_backoff and retry_last_ts + retry_interval > now:
             raise NotRetryingDestination(
                 retry_last_ts=retry_last_ts,
                 retry_interval=retry_interval,
@@ -124,7 +133,13 @@ class RetryDestinationLimiter(object):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         valid_err_code = False
-        if exc_type is not None and issubclass(exc_type, CodeMessageException):
+        if exc_type is None:
+            valid_err_code = True
+        elif not issubclass(exc_type, Exception):
+            # avoid treating exceptions which don't derive from Exception as
+            # failures; this is mostly so as not to catch defer._DefGen.
+            valid_err_code = True
+        elif issubclass(exc_type, CodeMessageException):
             # Some error codes are perfectly fine for some APIs, whereas other
             # APIs may expect to never received e.g. a 404. It's important to
             # handle 404 as some remote servers will return a 404 when the HS
@@ -142,11 +157,13 @@ class RetryDestinationLimiter(object):
             else:
                 valid_err_code = False
 
-        if exc_type is None or valid_err_code:
+        if valid_err_code:
             # We connected successfully.
             if not self.retry_interval:
                 return
 
+            logger.debug("Connection to %s was successful; clearing backoff",
+                         self.destination)
             retry_last_ts = 0
             self.retry_interval = 0
         else:
@@ -160,6 +177,10 @@ class RetryDestinationLimiter(object):
             else:
                 self.retry_interval = self.min_retry_interval
 
+            logger.debug(
+                "Connection to %s was unsuccessful (%s(%s)); backoff now %i",
+                self.destination, exc_type, exc_val, self.retry_interval
+            )
             retry_last_ts = int(self.clock.time_msec())
 
         @defer.inlineCallbacks
@@ -173,4 +194,5 @@ class RetryDestinationLimiter(object):
                     "Failed to store set_destination_retry_timings",
                 )
 
-        store_retry_timings()
+        # we deliberately do this in the background.
+        synapse.util.logcontext.preserve_fn(store_retry_timings)()