summary refs log tree commit diff
path: root/synapse/util/async.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util/async.py')
-rw-r--r--synapse/util/async.py270
1 files changed, 163 insertions, 107 deletions
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 1453faf0ef..a7094e2fb4 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,38 +13,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import collections
+import logging
+from contextlib import contextmanager
+
+from six.moves import range
 
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
 
-from twisted.internet import defer, reactor
+from synapse.util import Clock, logcontext, unwrapFirstError
 
 from .logcontext import (
-    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
+    PreserveLoggingContext,
+    make_deferred_yieldable,
+    run_in_background,
 )
-from synapse.util import unwrapFirstError
-
-from contextlib import contextmanager
-
-import logging
 
 logger = logging.getLogger(__name__)
 
 
-@defer.inlineCallbacks
-def sleep(seconds):
-    d = defer.Deferred()
-    with PreserveLoggingContext():
-        reactor.callLater(seconds, d.callback, seconds)
-        res = yield d
-    defer.returnValue(res)
-
-
-def run_on_reactor():
-    """ This will cause the rest of the function to be invoked upon the next
-    iteration of the main loop
-    """
-    return sleep(0)
-
-
 class ObservableDeferred(object):
     """Wraps a deferred object so that we can add observer deferreds. These
     observer deferreds do not affect the callback chain of the original
@@ -53,6 +43,11 @@ class ObservableDeferred(object):
 
     Cancelling or otherwise resolving an observer will not affect the original
     ObservableDeferred.
+
+    NB that it does not attempt to do anything with logcontexts; in general
+    you should probably make_deferred_yieldable the deferreds
+    returned by `observe`, and ensure that the original deferred runs its
+    callbacks in the sentinel logcontext.
     """
 
     __slots__ = ["_deferred", "_observers", "_result"]
@@ -68,7 +63,7 @@ class ObservableDeferred(object):
                 try:
                     # TODO: Handle errors here.
                     self._observers.pop().callback(r)
-                except:
+                except Exception:
                     pass
             return r
 
@@ -78,7 +73,7 @@ class ObservableDeferred(object):
                 try:
                     # TODO: Handle errors here.
                     self._observers.pop().errback(f)
-                except:
+                except Exception:
                     pass
 
             if consumeErrors:
@@ -151,77 +146,19 @@ def concurrently_execute(func, args, limit):
     def _concurrently_execute_inner():
         try:
             while True:
-                yield func(it.next())
+                yield func(next(it))
         except StopIteration:
             pass
 
-    return preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(_concurrently_execute_inner)()
-        for _ in xrange(limit)
+    return logcontext.make_deferred_yieldable(defer.gatherResults([
+        run_in_background(_concurrently_execute_inner)
+        for _ in range(limit)
     ], consumeErrors=True)).addErrback(unwrapFirstError)
 
 
 class Linearizer(object):
-    """Linearizes access to resources based on a key. Useful to ensure only one
-    thing is happening at a time on a given resource.
-
-    Example:
-
-        with (yield linearizer.queue("test_key")):
-            # do some work.
-
-    """
-    def __init__(self, name=None):
-        if name is None:
-            self.name = id(self)
-        else:
-            self.name = name
-        self.key_to_defer = {}
-
-    @defer.inlineCallbacks
-    def queue(self, key):
-        # If there is already a deferred in the queue, we pull it out so that
-        # we can wait on it later.
-        # Then we replace it with a deferred that we resolve *after* the
-        # context manager has exited.
-        # We only return the context manager after the previous deferred has
-        # resolved.
-        # This all has the net effect of creating a chain of deferreds that
-        # wait for the previous deferred before starting their work.
-        current_defer = self.key_to_defer.get(key)
-
-        new_defer = defer.Deferred()
-        self.key_to_defer[key] = new_defer
-
-        if current_defer:
-            logger.info(
-                "Waiting to acquire linearizer lock %r for key %r", self.name, key
-            )
-            try:
-                with PreserveLoggingContext():
-                    yield current_defer
-            except:
-                logger.exception("Unexpected exception in Linearizer")
-
-        logger.info("Acquired linearizer lock %r for key %r", self.name, key)
-
-        @contextmanager
-        def _ctx_manager():
-            try:
-                yield
-            finally:
-                logger.info("Releasing linearizer lock %r for key %r", self.name, key)
-                new_defer.callback(None)
-                current_d = self.key_to_defer.get(key)
-                if current_d is new_defer:
-                    self.key_to_defer.pop(key, None)
-
-        defer.returnValue(_ctx_manager())
-
-
-class Limiter(object):
     """Limits concurrent access to resources based on a key. Useful to ensure
-    only a few thing happen at a time on a given resource.
+    only a few things happen at a time on a given resource.
 
     Example:
 
@@ -229,22 +166,31 @@ class Limiter(object):
             # do some work.
 
     """
-    def __init__(self, max_count):
+    def __init__(self, name=None, max_count=1, clock=None):
         """
         Args:
-            max_count(int): The maximum number of concurrent access
+            max_count(int): The maximum number of concurrent accesses
         """
+        if name is None:
+            self.name = id(self)
+        else:
+            self.name = name
+
+        if not clock:
+            from twisted.internet import reactor
+            clock = Clock(reactor)
+        self._clock = clock
         self.max_count = max_count
 
         # key_to_defer is a map from the key to a 2 element list where
-        # the first element is the number of things executing
-        # the second element is a list of deferreds for the things blocked from
-        # executing.
+        # the first element is the number of things executing, and
+        # the second element is an OrderedDict, where the keys are deferreds for the
+        # things blocked from executing.
         self.key_to_defer = {}
 
     @defer.inlineCallbacks
     def queue(self, key):
-        entry = self.key_to_defer.setdefault(key, [0, []])
+        entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
 
         # If the number of things executing is greater than the maximum
         # then add a deferred to the list of blocked items
@@ -252,27 +198,71 @@ class Limiter(object):
         # this item so that it can continue executing.
         if entry[0] >= self.max_count:
             new_defer = defer.Deferred()
-            entry[1].append(new_defer)
-            with PreserveLoggingContext():
-                yield new_defer
+            entry[1][new_defer] = 1
+
+            logger.info(
+                "Waiting to acquire linearizer lock %r for key %r", self.name, key,
+            )
+            try:
+                yield make_deferred_yieldable(new_defer)
+            except Exception as e:
+                if isinstance(e, CancelledError):
+                    logger.info(
+                        "Cancelling wait for linearizer lock %r for key %r",
+                        self.name, key,
+                    )
+                else:
+                    logger.warn(
+                        "Unexpected exception waiting for linearizer lock %r for key %r",
+                        self.name, key,
+                    )
+
+                # we just have to take ourselves back out of the queue.
+                del entry[1][new_defer]
+                raise
+
+            logger.info("Acquired linearizer lock %r for key %r", self.name, key)
+            entry[0] += 1
+
+            # if the code holding the lock completes synchronously, then it
+            # will recursively run the next claimant on the list. That can
+            # relatively rapidly lead to stack exhaustion. This is essentially
+            # the same problem as http://twistedmatrix.com/trac/ticket/9304.
+            #
+            # In order to break the cycle, we add a cheeky sleep(0) here to
+            # ensure that we fall back to the reactor between each iteration.
+            #
+            # (This needs to happen while we hold the lock, and the context manager's exit
+            # code must be synchronous, so this is the only sensible place.)
+            yield self._clock.sleep(0)
 
-        entry[0] += 1
+        else:
+            logger.info(
+                "Acquired uncontended linearizer lock %r for key %r", self.name, key,
+            )
+            entry[0] += 1
 
         @contextmanager
         def _ctx_manager():
             try:
                 yield
             finally:
+                logger.info("Releasing linearizer lock %r for key %r", self.name, key)
+
                 # We've finished executing so check if there are any things
                 # blocked waiting to execute and start one of them
                 entry[0] -= 1
-                try:
-                    entry[1].pop(0).callback(None)
-                except IndexError:
-                    # If nothing else is executing for this key then remove it
-                    # from the map
-                    if entry[0] == 0:
-                        self.key_to_defer.pop(key, None)
+
+                if entry[1]:
+                    (next_def, _) = entry[1].popitem(last=False)
+
+                    # we need to run the next thing in the sentinel context.
+                    with PreserveLoggingContext():
+                        next_def.callback(None)
+                elif entry[0] == 0:
+                    # We were the last thing for this key: remove it from the
+                    # map.
+                    del self.key_to_defer[key]
 
         defer.returnValue(_ctx_manager())
 
@@ -316,7 +306,7 @@ class ReadWriteLock(object):
 
         # We wait for the latest writer to finish writing. We can safely ignore
         # any existing readers... as they're readers.
-        yield curr_writer
+        yield make_deferred_yieldable(curr_writer)
 
         @contextmanager
         def _ctx_manager():
@@ -345,7 +335,7 @@ class ReadWriteLock(object):
         curr_readers.clear()
         self.key_to_current_writer[key] = new_defer
 
-        yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
+        yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
 
         @contextmanager
         def _ctx_manager():
@@ -357,3 +347,69 @@ class ReadWriteLock(object):
                     self.key_to_current_writer.pop(key)
 
         defer.returnValue(_ctx_manager())
+
+
+class DeferredTimeoutError(Exception):
+    """
+    This error is raised by default when a L{Deferred} times out.
+    """
+
+
+def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
+    """
+    Add a timeout to a deferred by scheduling it to be cancelled after
+    timeout seconds.
+
+    This is essentially a backport of deferred.addTimeout, which was introduced
+    in twisted 16.5.
+
+    If the deferred gets timed out, it errbacks with a DeferredTimeoutError,
+    unless a cancelable function was passed to its initialization or unless
+    a different on_timeout_cancel callable is provided.
+
+    Args:
+        deferred (defer.Deferred): deferred to be timed out
+        timeout (Number): seconds to time out after
+        reactor (twisted.internet.reactor): the Twisted reactor to use
+
+        on_timeout_cancel (callable): A callable which is called immediately
+            after the deferred times out, and not if this deferred is
+            otherwise cancelled before the timeout.
+
+            It takes an arbitrary value, which is the value of the deferred at
+            that exact point in time (probably a CancelledError Failure), and
+            the timeout.
+
+            The default callable (if none is provided) will translate a
+            CancelledError Failure into a DeferredTimeoutError.
+    """
+    timed_out = [False]
+
+    def time_it_out():
+        timed_out[0] = True
+        deferred.cancel()
+
+    delayed_call = reactor.callLater(timeout, time_it_out)
+
+    def convert_cancelled(value):
+        if timed_out[0]:
+            to_call = on_timeout_cancel or _cancelled_to_timed_out_error
+            return to_call(value, timeout)
+        return value
+
+    deferred.addBoth(convert_cancelled)
+
+    def cancel_timeout(result):
+        # stop the pending call to cancel the deferred if it's been fired
+        if delayed_call.active():
+            delayed_call.cancel()
+        return result
+
+    deferred.addBoth(cancel_timeout)
+
+
+def _cancelled_to_timed_out_error(value, timeout):
+    if isinstance(value, failure.Failure):
+        value.trap(CancelledError)
+        raise DeferredTimeoutError(timeout, "Deferred")
+    return value