summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/async.py145
-rw-r--r--synapse/util/caches/descriptors.py131
-rw-r--r--synapse/util/caches/expiringcache.py6
-rw-r--r--synapse/util/caches/stream_change_cache.py4
-rw-r--r--synapse/util/distributor.py48
-rw-r--r--synapse/util/logcontext.py11
-rw-r--r--synapse/util/metrics.py19
7 files changed, 171 insertions, 193 deletions
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 5d0fb39130..a7094e2fb4 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import collections
 import logging
 from contextlib import contextmanager
 
@@ -156,54 +157,72 @@ def concurrently_execute(func, args, limit):
 
 
 class Linearizer(object):
-    """Linearizes access to resources based on a key. Useful to ensure only one
-    thing is happening at a time on a given resource.
+    """Limits concurrent access to resources based on a key. Useful to ensure
+    only a few things happen at a time on a given resource.
 
     Example:
 
-        with (yield linearizer.queue("test_key")):
+        with (yield limiter.queue("test_key")):
             # do some work.
 
     """
-    def __init__(self, name=None, clock=None):
+    def __init__(self, name=None, max_count=1, clock=None):
+        """
+        Args:
+            max_count(int): The maximum number of concurrent accesses
+        """
         if name is None:
             self.name = id(self)
         else:
             self.name = name
-        self.key_to_defer = {}
 
         if not clock:
             from twisted.internet import reactor
             clock = Clock(reactor)
         self._clock = clock
+        self.max_count = max_count
+
+        # key_to_defer is a map from the key to a 2 element list where
+        # the first element is the number of things executing, and
+        # the second element is an OrderedDict, where the keys are deferreds for the
+        # things blocked from executing.
+        self.key_to_defer = {}
 
     @defer.inlineCallbacks
     def queue(self, key):
-        # If there is already a deferred in the queue, we pull it out so that
-        # we can wait on it later.
-        # Then we replace it with a deferred that we resolve *after* the
-        # context manager has exited.
-        # We only return the context manager after the previous deferred has
-        # resolved.
-        # This all has the net effect of creating a chain of deferreds that
-        # wait for the previous deferred before starting their work.
-        current_defer = self.key_to_defer.get(key)
+        entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
 
-        new_defer = defer.Deferred()
-        self.key_to_defer[key] = new_defer
+        # If the number of things executing is greater than the maximum
+        # then add a deferred to the list of blocked items
+        # When on of the things currently executing finishes it will callback
+        # this item so that it can continue executing.
+        if entry[0] >= self.max_count:
+            new_defer = defer.Deferred()
+            entry[1][new_defer] = 1
 
-        if current_defer:
             logger.info(
-                "Waiting to acquire linearizer lock %r for key %r", self.name, key
+                "Waiting to acquire linearizer lock %r for key %r", self.name, key,
             )
             try:
-                with PreserveLoggingContext():
-                    yield current_defer
-            except Exception:
-                logger.exception("Unexpected exception in Linearizer")
-
-            logger.info("Acquired linearizer lock %r for key %r", self.name,
-                        key)
+                yield make_deferred_yieldable(new_defer)
+            except Exception as e:
+                if isinstance(e, CancelledError):
+                    logger.info(
+                        "Cancelling wait for linearizer lock %r for key %r",
+                        self.name, key,
+                    )
+                else:
+                    logger.warn(
+                        "Unexpected exception waiting for linearizer lock %r for key %r",
+                        self.name, key,
+                    )
+
+                # we just have to take ourselves back out of the queue.
+                del entry[1][new_defer]
+                raise
+
+            logger.info("Acquired linearizer lock %r for key %r", self.name, key)
+            entry[0] += 1
 
             # if the code holding the lock completes synchronously, then it
             # will recursively run the next claimant on the list. That can
@@ -213,15 +232,15 @@ class Linearizer(object):
             # In order to break the cycle, we add a cheeky sleep(0) here to
             # ensure that we fall back to the reactor between each iteration.
             #
-            # (There's no particular need for it to happen before we return
-            # the context manager, but it needs to happen while we hold the
-            # lock, and the context manager's exit code must be synchronous,
-            # so actually this is the only sensible place.
+            # (This needs to happen while we hold the lock, and the context manager's exit
+            # code must be synchronous, so this is the only sensible place.)
             yield self._clock.sleep(0)
 
         else:
-            logger.info("Acquired uncontended linearizer lock %r for key %r",
-                        self.name, key)
+            logger.info(
+                "Acquired uncontended linearizer lock %r for key %r", self.name, key,
+            )
+            entry[0] += 1
 
         @contextmanager
         def _ctx_manager():
@@ -229,73 +248,15 @@ class Linearizer(object):
                 yield
             finally:
                 logger.info("Releasing linearizer lock %r for key %r", self.name, key)
-                with PreserveLoggingContext():
-                    new_defer.callback(None)
-                current_d = self.key_to_defer.get(key)
-                if current_d is new_defer:
-                    self.key_to_defer.pop(key, None)
-
-        defer.returnValue(_ctx_manager())
-
-
-class Limiter(object):
-    """Limits concurrent access to resources based on a key. Useful to ensure
-    only a few thing happen at a time on a given resource.
-
-    Example:
-
-        with (yield limiter.queue("test_key")):
-            # do some work.
-
-    """
-    def __init__(self, max_count):
-        """
-        Args:
-            max_count(int): The maximum number of concurrent access
-        """
-        self.max_count = max_count
-
-        # key_to_defer is a map from the key to a 2 element list where
-        # the first element is the number of things executing
-        # the second element is a list of deferreds for the things blocked from
-        # executing.
-        self.key_to_defer = {}
-
-    @defer.inlineCallbacks
-    def queue(self, key):
-        entry = self.key_to_defer.setdefault(key, [0, []])
-
-        # If the number of things executing is greater than the maximum
-        # then add a deferred to the list of blocked items
-        # When on of the things currently executing finishes it will callback
-        # this item so that it can continue executing.
-        if entry[0] >= self.max_count:
-            new_defer = defer.Deferred()
-            entry[1].append(new_defer)
-
-            logger.info("Waiting to acquire limiter lock for key %r", key)
-            with PreserveLoggingContext():
-                yield new_defer
-            logger.info("Acquired limiter lock for key %r", key)
-        else:
-            logger.info("Acquired uncontended limiter lock for key %r", key)
-
-        entry[0] += 1
-
-        @contextmanager
-        def _ctx_manager():
-            try:
-                yield
-            finally:
-                logger.info("Releasing limiter lock for key %r", key)
 
                 # We've finished executing so check if there are any things
                 # blocked waiting to execute and start one of them
                 entry[0] -= 1
 
                 if entry[1]:
-                    next_def = entry[1].pop(0)
+                    (next_def, _) = entry[1].popitem(last=False)
 
+                    # we need to run the next thing in the sentinel context.
                     with PreserveLoggingContext():
                         next_def.callback(None)
                 elif entry[0] == 0:
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f8a07df6b8..861c24809c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
-            # If we're passed a cache_context then we'll want to call its invalidate()
-            # whenever we are invalidated
+            # If we're passed a cache_context then we'll want to call its
+            # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            # cached is a dict arg -> deferred, where deferred results in a
-            # 2-tuple (`arg`, `result`)
             results = {}
-            cached_defers = {}
-            missing = []
+
+            def update_results_dict(res, arg):
+                results[arg] = res
+
+            # list of deferreds to wait for
+            cached_defers = []
+
+            missing = set()
 
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
-                def cache_get(arg):
-                    return cache.get(arg, callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    return arg
             else:
-                key = list(keyargs)
+                keylist = list(keyargs)
 
-                def cache_get(arg):
-                    key[self.list_pos] = arg
-                    return cache.get(tuple(key), callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    keylist[self.list_pos] = arg
+                    return tuple(keylist)
 
             for arg in list_args:
                 try:
-                    res = cache_get(arg)
-
+                    res = cache.get(arg_to_cache_key(arg),
+                                    callback=invalidate_callback)
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
                         res = res.observe()
-                        res.addCallback(lambda r, arg: (arg, r), arg)
-                        cached_defers[arg] = res
+                        res.addCallback(update_results_dict, arg)
+                        cached_defers.append(res)
                     else:
                         results[arg] = res.get_result()
                 except KeyError:
-                    missing.append(arg)
+                    missing.add(arg)
 
             if missing:
+                # we need an observable deferred for each entry in the list,
+                # which we put in the cache. Each deferred resolves with the
+                # relevant result for that key.
+                deferreds_map = {}
+                for arg in missing:
+                    deferred = defer.Deferred()
+                    deferreds_map[arg] = deferred
+                    key = arg_to_cache_key(arg)
+                    observable = ObservableDeferred(deferred)
+                    cache.set(key, observable, callback=invalidate_callback)
+
+                def complete_all(res):
+                    # the wrapped function has completed. It returns a
+                    # a dict. We can now resolve the observable deferreds in
+                    # the cache and update our own result map.
+                    for e in missing:
+                        val = res.get(e, None)
+                        deferreds_map[e].callback(val)
+                        results[e] = val
+
+                def errback(f):
+                    # the wrapped function has failed. Invalidate any cache
+                    # entries we're supposed to be populating, and fail
+                    # their deferreds.
+                    for e in missing:
+                        key = arg_to_cache_key(e)
+                        cache.invalidate(key)
+                        deferreds_map[e].errback(f)
+
+                    # return the failure, to propagate to our caller.
+                    return f
+
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = list(missing)
 
-                ret_d = defer.maybeDeferred(
+                cached_defers.append(defer.maybeDeferred(
                     logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
-                )
-
-                ret_d = ObservableDeferred(ret_d)
-
-                # We need to create deferreds for each arg in the list so that
-                # we can insert the new deferred into the cache.
-                for arg in missing:
-                    observer = ret_d.observe()
-                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)
-
-                    observer = ObservableDeferred(observer)
-
-                    if num_args == 1:
-                        cache.set(
-                            arg, observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, arg)
-                    else:
-                        key = list(keyargs)
-                        key[self.list_pos] = arg
-                        cache.set(
-                            tuple(key), observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, tuple(key))
-
-                    res = observer.observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-
-                    cached_defers[arg] = res
+                ).addCallbacks(complete_all, errback))
 
             if cached_defers:
-                def update_results_dict(res):
-                    results.update(res)
-                    return results
-
-                return logcontext.make_deferred_yieldable(defer.gatherResults(
-                    list(cached_defers.values()),
+                d = defer.gatherResults(
+                    cached_defers,
                     consumeErrors=True,
-                ).addCallback(update_results_dict).addErrback(
+                ).addCallbacks(
+                    lambda _: results,
                     unwrapFirstError
-                ))
+                )
+                return logcontext.make_deferred_yieldable(d)
             else:
                 return results
 
@@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
     cache.
 
     Args:
-        cache (Cache): The underlying cache to use.
+        cached_method_name (str): The name of the single-item lookup method.
+            This is only used to find the cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
         num_args (int): Number of arguments to use as the key in the cache
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 4abca91f6d..ce85b2ae11 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -16,6 +16,7 @@
 import logging
 from collections import OrderedDict
 
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
@@ -63,7 +64,10 @@ class ExpiringCache(object):
             return
 
         def f():
-            self._prune_cache()
+            return run_as_background_process(
+                "prune_cache_%s" % self._cache_name,
+                self._prune_cache,
+            )
 
         self._clock.looping_call(f, self._expiry_ms / 2)
 
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 258655349b..f2bde74dc5 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -74,12 +74,14 @@ class StreamChangeCache(object):
         assert type(stream_pos) is int
 
         if stream_pos >= self._earliest_known_stream_pos:
-            result = {
+            changed_entities = {
                 self._cache[k] for k in self._cache.islice(
                     start=self._cache.bisect_right(stream_pos),
                 )
             }
 
+            result = changed_entities.intersection(entities)
+
             self.metrics.inc_hits()
         else:
             result = set(entities)
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 734331caaa..194da87639 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -17,20 +17,18 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 
 logger = logging.getLogger(__name__)
 
 
 def user_left_room(distributor, user, room_id):
-    with PreserveLoggingContext():
-        distributor.fire("user_left_room", user=user, room_id=room_id)
+    distributor.fire("user_left_room", user=user, room_id=room_id)
 
 
 def user_joined_room(distributor, user, room_id):
-    with PreserveLoggingContext():
-        distributor.fire("user_joined_room", user=user, room_id=room_id)
+    distributor.fire("user_joined_room", user=user, room_id=room_id)
 
 
 class Distributor(object):
@@ -44,9 +42,7 @@ class Distributor(object):
       model will do for today.
     """
 
-    def __init__(self, suppress_failures=True):
-        self.suppress_failures = suppress_failures
-
+    def __init__(self):
         self.signals = {}
         self.pre_registration = {}
 
@@ -56,7 +52,6 @@ class Distributor(object):
 
         self.signals[name] = Signal(
             name,
-            suppress_failures=self.suppress_failures,
         )
 
         if name in self.pre_registration:
@@ -75,10 +70,18 @@ class Distributor(object):
             self.pre_registration[name].append(observer)
 
     def fire(self, name, *args, **kwargs):
+        """Dispatches the given signal to the registered observers.
+
+        Runs the observers as a background process. Does not return a deferred.
+        """
         if name not in self.signals:
             raise KeyError("%r does not have a signal named %s" % (self, name))
 
-        return self.signals[name].fire(*args, **kwargs)
+        run_as_background_process(
+            name,
+            self.signals[name].fire,
+            *args, **kwargs
+        )
 
 
 class Signal(object):
@@ -91,9 +94,8 @@ class Signal(object):
     method into all of the observers.
     """
 
-    def __init__(self, name, suppress_failures):
+    def __init__(self, name):
         self.name = name
-        self.suppress_failures = suppress_failures
         self.observers = []
 
     def observe(self, observer):
@@ -103,7 +105,6 @@ class Signal(object):
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
-    @defer.inlineCallbacks
     def fire(self, *args, **kwargs):
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -121,22 +122,17 @@ class Signal(object):
                         failure.type,
                         failure.value,
                         failure.getTracebackObject()))
-                if not self.suppress_failures:
-                    return failure
 
             return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
 
-        with PreserveLoggingContext():
-            deferreds = [
-                do(observer)
-                for observer in self.observers
-            ]
-
-            res = yield defer.gatherResults(
-                deferreds, consumeErrors=True
-            ).addErrback(unwrapFirstError)
+        deferreds = [
+            run_in_background(do, o)
+            for o in self.observers
+        ]
 
-        defer.returnValue(res)
+        return make_deferred_yieldable(defer.gatherResults(
+            deferreds, consumeErrors=True,
+        ))
 
     def __repr__(self):
         return "<Signal name=%r>" % (self.name,)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index f6c7175f74..8dcae50b39 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -99,6 +99,17 @@ class ContextResourceUsage(object):
         self.db_sched_duration_sec = 0
         self.evt_db_fetch_count = 0
 
+    def __repr__(self):
+        return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
+                "db_txn_count='%r', db_txn_duration_sec='%r', "
+                "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % (
+                    self.ru_stime,
+                    self.ru_utime,
+                    self.db_txn_count,
+                    self.db_txn_duration_sec,
+                    self.db_sched_duration_sec,
+                    self.evt_db_fetch_count,)
+
     def __iadd__(self, other):
         """Add another ContextResourceUsage's stats to this one's.
 
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 6ba7107896..97f1267380 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -104,12 +104,19 @@ class Measure(object):
             logger.warn("Expected context. (%r)", self.name)
             return
 
-        usage = context.get_resource_usage() - self.start_usage
-        block_ru_utime.labels(self.name).inc(usage.ru_utime)
-        block_ru_stime.labels(self.name).inc(usage.ru_stime)
-        block_db_txn_count.labels(self.name).inc(usage.db_txn_count)
-        block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
-        block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
+        current = context.get_resource_usage()
+        usage = current - self.start_usage
+        try:
+            block_ru_utime.labels(self.name).inc(usage.ru_utime)
+            block_ru_stime.labels(self.name).inc(usage.ru_stime)
+            block_db_txn_count.labels(self.name).inc(usage.db_txn_count)
+            block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
+            block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
+        except ValueError:
+            logger.warn(
+                "Failed to save metrics! OLD: %r, NEW: %r",
+                self.start_usage, current
+            )
 
         if self.created_context:
             self.start_context.__exit__(exc_type, exc_val, exc_tb)