diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 260714ccc2..07ff25cef3 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -91,8 +91,12 @@ class Clock(object):
with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
- def cancel_call_later(self, timer):
- timer.cancel()
+ def cancel_call_later(self, timer, ignore_errs=False):
+ try:
+ timer.cancel()
+ except:
+ if not ignore_errs:
+ raise
def time_bound_deferred(self, given_deferred, time_out):
if given_deferred.called:
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 1c2044e5b4..5a1d545c96 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -38,6 +38,9 @@ class ObservableDeferred(object):
deferred.
If consumeErrors is true errors will be captured from the origin deferred.
+
+ Cancelling or otherwise resolving an observer will not affect the original
+ ObservableDeferred.
"""
__slots__ = ["_deferred", "_observers", "_result"]
@@ -45,7 +48,7 @@ class ObservableDeferred(object):
def __init__(self, deferred, consumeErrors=False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
- object.__setattr__(self, "_observers", [])
+ object.__setattr__(self, "_observers", set())
def callback(r):
self._result = (True, r)
@@ -74,12 +77,21 @@ class ObservableDeferred(object):
def observe(self):
if not self._result:
d = defer.Deferred()
- self._observers.append(d)
+
+ def remove(r):
+ self._observers.discard(d)
+ return r
+ d.addBoth(remove)
+
+ self._observers.add(d)
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
+ def observers(self):
+ return self._observers
+
def __getattr__(self, name):
return getattr(self._deferred, name)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index a92d518b43..7e6062c1b8 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -140,6 +140,37 @@ class PreserveLoggingContext(object):
)
+class _PreservingContextDeferred(defer.Deferred):
+ """A deferred that ensures that all callbacks and errbacks are called with
+ the given logging context.
+ """
+ def __init__(self, context):
+ self._log_context = context
+ defer.Deferred.__init__(self)
+
+ def addCallbacks(self, callback, errback=None,
+ callbackArgs=None, callbackKeywords=None,
+ errbackArgs=None, errbackKeywords=None):
+ callback = self._wrap_callback(callback)
+ errback = self._wrap_callback(errback)
+ return defer.Deferred.addCallbacks(
+ self, callback,
+ errback=errback,
+ callbackArgs=callbackArgs,
+ callbackKeywords=callbackKeywords,
+ errbackArgs=errbackArgs,
+ errbackKeywords=errbackKeywords,
+ )
+
+ def _wrap_callback(self, f):
+ def g(res, *args, **kwargs):
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = self._log_context
+ res = f(res, *args, **kwargs)
+ return res
+ return g
+
+
def preserve_context_over_fn(fn, *args, **kwargs):
"""Takes a function and invokes it with the given arguments, but removes
and restores the current logging context while doing so.
@@ -160,24 +191,7 @@ def preserve_context_over_deferred(deferred):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
"""
- d = defer.Deferred()
-
current_context = LoggingContext.current_context()
-
- def cb(res):
- with PreserveLoggingContext():
- LoggingContext.thread_local.current_context = current_context
- res = d.callback(res)
- return res
-
- def eb(failure):
- with PreserveLoggingContext():
- LoggingContext.thread_local.current_context = current_context
- res = d.errback(failure)
- return res
-
- if deferred.called:
- return deferred
-
- deferred.addCallbacks(cb, eb)
+ d = _PreservingContextDeferred(current_context)
+ deferred.chainDeferred(d)
return d
|