summary refs log tree commit diff
path: root/synapse/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/util')
-rw-r--r--synapse/util/__init__.py56
-rw-r--r--synapse/util/async.py77
-rw-r--r--synapse/util/file_consumer.py6
-rw-r--r--synapse/util/httpresourcetree.py7
-rw-r--r--synapse/util/logcontext.py81
-rw-r--r--synapse/util/logformatter.py4
-rw-r--r--synapse/util/ratelimitutils.py19
-rw-r--r--synapse/util/retryutils.py4
-rw-r--r--synapse/util/stringutils.py5
-rw-r--r--synapse/util/wheel_timer.py4
10 files changed, 161 insertions, 102 deletions
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 756d8ffa32..814a7bf71b 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.api.errors import SynapseError
 from synapse.util.logcontext import PreserveLoggingContext
 
 from twisted.internet import defer, reactor, task
@@ -24,11 +23,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class DeferredTimedOutError(SynapseError):
-    def __init__(self):
-        super(DeferredTimedOutError, self).__init__(504, "Timed out")
-
-
 def unwrapFirstError(failure):
     # defer.gatherResults and DeferredLists wrap failures.
     failure.trap(defer.FirstError)
@@ -85,53 +79,3 @@ class Clock(object):
         except Exception:
             if not ignore_errs:
                 raise
-
-    def time_bound_deferred(self, given_deferred, time_out):
-        if given_deferred.called:
-            return given_deferred
-
-        ret_deferred = defer.Deferred()
-
-        def timed_out_fn():
-            e = DeferredTimedOutError()
-
-            try:
-                ret_deferred.errback(e)
-            except Exception:
-                pass
-
-            try:
-                given_deferred.cancel()
-            except Exception:
-                pass
-
-        timer = None
-
-        def cancel(res):
-            try:
-                self.cancel_call_later(timer)
-            except Exception:
-                pass
-            return res
-
-        ret_deferred.addBoth(cancel)
-
-        def success(res):
-            try:
-                ret_deferred.callback(res)
-            except Exception:
-                pass
-
-            return res
-
-        def err(res):
-            try:
-                ret_deferred.errback(res)
-            except Exception:
-                pass
-
-        given_deferred.addCallbacks(callback=success, errback=err)
-
-        timer = self.call_later(time_out, timed_out_fn)
-
-        return ret_deferred
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 0729bb2863..9dd4e6b5bc 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -15,9 +15,11 @@
 
 
 from twisted.internet import defer, reactor
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
 
 from .logcontext import (
-    PreserveLoggingContext, make_deferred_yieldable, preserve_fn
+    PreserveLoggingContext, make_deferred_yieldable, run_in_background
 )
 from synapse.util import logcontext, unwrapFirstError
 
@@ -25,6 +27,8 @@ from contextlib import contextmanager
 
 import logging
 
+from six.moves import range
+
 logger = logging.getLogger(__name__)
 
 
@@ -156,13 +160,13 @@ def concurrently_execute(func, args, limit):
     def _concurrently_execute_inner():
         try:
             while True:
-                yield func(it.next())
+                yield func(next(it))
         except StopIteration:
             pass
 
     return logcontext.make_deferred_yieldable(defer.gatherResults([
-        preserve_fn(_concurrently_execute_inner)()
-        for _ in xrange(limit)
+        run_in_background(_concurrently_execute_inner)
+        for _ in range(limit)
     ], consumeErrors=True)).addErrback(unwrapFirstError)
 
 
@@ -392,3 +396,68 @@ class ReadWriteLock(object):
                     self.key_to_current_writer.pop(key)
 
         defer.returnValue(_ctx_manager())
+
+
+class DeferredTimeoutError(Exception):
+    """
+    This error is raised by default when a L{Deferred} times out.
+    """
+
+
+def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
+    """
+    Add a timeout to a deferred by scheduling it to be cancelled after
+    timeout seconds.
+
+    This is essentially a backport of deferred.addTimeout, which was introduced
+    in twisted 16.5.
+
+    If the deferred gets timed out, it errbacks with a DeferredTimeoutError,
+    unless a cancelable function was passed to its initialization or unless
+    a different on_timeout_cancel callable is provided.
+
+    Args:
+        deferred (defer.Deferred): deferred to be timed out
+        timeout (Number): seconds to time out after
+
+        on_timeout_cancel (callable): A callable which is called immediately
+            after the deferred times out, and not if this deferred is
+            otherwise cancelled before the timeout.
+
+            It takes an arbitrary value, which is the value of the deferred at
+            that exact point in time (probably a CancelledError Failure), and
+            the timeout.
+
+            The default callable (if none is provided) will translate a
+            CancelledError Failure into a DeferredTimeoutError.
+    """
+    timed_out = [False]
+
+    def time_it_out():
+        timed_out[0] = True
+        deferred.cancel()
+
+    delayed_call = reactor.callLater(timeout, time_it_out)
+
+    def convert_cancelled(value):
+        if timed_out[0]:
+            to_call = on_timeout_cancel or _cancelled_to_timed_out_error
+            return to_call(value, timeout)
+        return value
+
+    deferred.addBoth(convert_cancelled)
+
+    def cancel_timeout(result):
+        # stop the pending call to cancel the deferred if it's been fired
+        if delayed_call.active():
+            delayed_call.cancel()
+        return result
+
+    deferred.addBoth(cancel_timeout)
+
+
+def _cancelled_to_timed_out_error(value, timeout):
+    if isinstance(value, failure.Failure):
+        value.trap(CancelledError)
+        raise DeferredTimeoutError(timeout, "Deferred")
+    return value
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 3c8a165331..3380970e4e 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import threads, reactor
 
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 
 from six.moves import queue
 
@@ -70,7 +70,9 @@ class BackgroundFileConsumer(object):
 
         self._producer = producer
         self.streaming = streaming
-        self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
+        self._finished_deferred = run_in_background(
+            threads.deferToThread, self._writer
+        )
         if not streaming:
             self._producer.resumeProducing()
 
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index d747849553..e9f0f292ee 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -40,9 +40,12 @@ def create_resource_tree(desired_tree, root_resource):
     # extra resources to existing nodes. See self._resource_id for the key.
     resource_mappings = {}
     for full_path, res in desired_tree.items():
+        # twisted requires all resources to be bytes
+        full_path = full_path.encode("utf-8")
+
         logger.info("Attaching %s to path %s", res, full_path)
         last_resource = root_resource
-        for path_seg in full_path.split('/')[1:-1]:
+        for path_seg in full_path.split(b'/')[1:-1]:
             if path_seg not in last_resource.listNames():
                 # resource doesn't exist, so make a "dummy resource"
                 child_resource = NoResource()
@@ -57,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
 
         # ===========================
         # now attach the actual desired resource
-        last_path_seg = full_path.split('/')[-1]
+        last_path_seg = full_path.split(b'/')[-1]
 
         # if there is already a resource here, thieve its children and
         # replace it
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index d59adc236e..eab9d57650 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -164,7 +164,7 @@ class LoggingContext(object):
         current = self.set_current_context(self.previous_context)
         if current is not self:
             if current is self.sentinel:
-                logger.debug("Expected logging context %s has been lost", self)
+                logger.warn("Expected logging context %s has been lost", self)
             else:
                 logger.warn(
                     "Current logging context %s is not expected context %s",
@@ -279,7 +279,7 @@ class PreserveLoggingContext(object):
         context = LoggingContext.set_current_context(self.current_context)
 
         if context != self.new_context:
-            logger.debug(
+            logger.warn(
                 "Unexpected logging context: %s is not %s",
                 context, self.new_context,
             )
@@ -302,31 +302,49 @@ def preserve_fn(f):
 def run_in_background(f, *args, **kwargs):
     """Calls a function, ensuring that the current context is restored after
     return from the function, and that the sentinel context is set once the
-    deferred returned by the funtion completes.
+    deferred returned by the function completes.
 
     Useful for wrapping functions that return a deferred which you don't yield
-    on.
+    on (for instance because you want to pass it to deferred.gatherResults()).
+
+    Note that if you completely discard the result, you should make sure that
+    `f` doesn't raise any deferred exceptions, otherwise a scary-looking
+    CRITICAL error about an unhandled error will be logged without much
+    indication about where it came from.
     """
     current = LoggingContext.current_context()
-    res = f(*args, **kwargs)
-    if isinstance(res, defer.Deferred) and not res.called:
-        # The function will have reset the context before returning, so
-        # we need to restore it now.
-        LoggingContext.set_current_context(current)
-
-        # The original context will be restored when the deferred
-        # completes, but there is nothing waiting for it, so it will
-        # get leaked into the reactor or some other function which
-        # wasn't expecting it. We therefore need to reset the context
-        # here.
-        #
-        # (If this feels asymmetric, consider it this way: we are
-        # effectively forking a new thread of execution. We are
-        # probably currently within a ``with LoggingContext()`` block,
-        # which is supposed to have a single entry and exit point. But
-        # by spawning off another deferred, we are effectively
-        # adding a new exit point.)
-        res.addBoth(_set_context_cb, LoggingContext.sentinel)
+    try:
+        res = f(*args, **kwargs)
+    except:   # noqa: E722
+        # the assumption here is that the caller doesn't want to be disturbed
+        # by synchronous exceptions, so let's turn them into Failures.
+        return defer.fail()
+
+    if not isinstance(res, defer.Deferred):
+        return res
+
+    if res.called and not res.paused:
+        # The function should have maintained the logcontext, so we can
+        # optimise out the messing about
+        return res
+
+    # The function may have reset the context before returning, so
+    # we need to restore it now.
+    ctx = LoggingContext.set_current_context(current)
+
+    # The original context will be restored when the deferred
+    # completes, but there is nothing waiting for it, so it will
+    # get leaked into the reactor or some other function which
+    # wasn't expecting it. We therefore need to reset the context
+    # here.
+    #
+    # (If this feels asymmetric, consider it this way: we are
+    # effectively forking a new thread of execution. We are
+    # probably currently within a ``with LoggingContext()`` block,
+    # which is supposed to have a single entry and exit point. But
+    # by spawning off another deferred, we are effectively
+    # adding a new exit point.)
+    res.addBoth(_set_context_cb, ctx)
     return res
 
 
@@ -341,11 +359,20 @@ def make_deferred_yieldable(deferred):
     returning a deferred. Then, when the deferred completes, restores the
     current logcontext before running callbacks/errbacks.
 
-    (This is more-or-less the opposite operation to preserve_fn.)
+    (This is more-or-less the opposite operation to run_in_background.)
     """
-    if isinstance(deferred, defer.Deferred) and not deferred.called:
-        prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
-        deferred.addBoth(_set_context_cb, prev_context)
+    if not isinstance(deferred, defer.Deferred):
+        return deferred
+
+    if deferred.called and not deferred.paused:
+        # it looks like this deferred is ready to run any callbacks we give it
+        # immediately. We may as well optimise out the logcontext faffery.
+        return deferred
+
+    # ok, we can't be sure that a yield won't block, so let's reset the
+    # logcontext, and add a callback to the deferred to restore it.
+    prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+    deferred.addBoth(_set_context_cb, prev_context)
     return deferred
 
 
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
index cdbc4bffd7..3e42868ea9 100644
--- a/synapse/util/logformatter.py
+++ b/synapse/util/logformatter.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 
-import StringIO
+from six import StringIO
 import logging
 import traceback
 
@@ -32,7 +32,7 @@ class LogFormatter(logging.Formatter):
         super(LogFormatter, self).__init__(*args, **kwargs)
 
     def formatException(self, ei):
-        sio = StringIO.StringIO()
+        sio = StringIO()
         (typ, val, tb) = ei
 
         # log the stack above the exception capture point if possible, but
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 1101881a2d..0ab63c3d7d 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,7 +18,10 @@ from twisted.internet import defer
 from synapse.api.errors import LimitExceededError
 
 from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import (
+    run_in_background, make_deferred_yieldable,
+    PreserveLoggingContext,
+)
 
 import collections
 import contextlib
@@ -150,7 +153,7 @@ class _PerHostRatelimiter(object):
                 "Ratelimit [%s]: sleeping req",
                 id(request_id),
             )
-            ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
+            ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0)
 
             self.sleeping_requests.add(request_id)
 
@@ -176,6 +179,9 @@ class _PerHostRatelimiter(object):
             return r
 
         def on_err(r):
+            # XXX: why is this necessary? this is called before we start
+            # processing the request so why would the request be in
+            # current_processing?
             self.current_processing.discard(request_id)
             return r
 
@@ -187,7 +193,7 @@ class _PerHostRatelimiter(object):
 
         ret_defer.addCallbacks(on_start, on_err)
         ret_defer.addBoth(on_both)
-        return ret_defer
+        return make_deferred_yieldable(ret_defer)
 
     def _on_exit(self, request_id):
         logger.debug(
@@ -197,7 +203,12 @@ class _PerHostRatelimiter(object):
         self.current_processing.discard(request_id)
         try:
             request_id, deferred = self.ready_request_queue.popitem()
+
+            # XXX: why do we do the following? the on_start callback above will
+            # do it for us.
             self.current_processing.add(request_id)
-            deferred.callback(None)
+
+            with PreserveLoggingContext():
+                deferred.callback(None)
         except KeyError:
             pass
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 47b0bb5eb3..4e93f69d3a 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -203,8 +203,8 @@ class RetryDestinationLimiter(object):
                 )
             except Exception:
                 logger.exception(
-                    "Failed to store set_destination_retry_timings",
+                    "Failed to store destination_retry_timings",
                 )
 
         # we deliberately do this in the background.
-        synapse.util.logcontext.preserve_fn(store_retry_timings)()
+        synapse.util.logcontext.run_in_background(store_retry_timings)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 95a6168e16..b98b9dc6e4 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -15,6 +15,7 @@
 
 import random
 import string
+from six.moves import range
 
 _string_with_symbols = (
     string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -22,12 +23,12 @@ _string_with_symbols = (
 
 
 def random_string(length):
-    return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
+    return ''.join(random.choice(string.ascii_letters) for _ in range(length))
 
 
 def random_string_with_symbols(length):
     return ''.join(
-        random.choice(_string_with_symbols) for _ in xrange(length)
+        random.choice(_string_with_symbols) for _ in range(length)
     )
 
 
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index b70f9a6b0a..7a9e45aca9 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from six.moves import range
+
 
 class _Entry(object):
     __slots__ = ["end_key", "queue"]
@@ -68,7 +70,7 @@ class WheelTimer(object):
         # Add empty entries between the end of the current list and when we want
         # to insert. This ensures there are no gaps.
         self.entries.extend(
-            _Entry(key) for key in xrange(last_key, then_key + 1)
+            _Entry(key) for key in range(last_key, then_key + 1)
         )
 
         self.entries[-1].queue.append(obj)