diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index d9c4af9389..c0d2f06855 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -143,6 +143,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
respond_404(request)
return
+ logger.debug("Responding to media request with responder %s")
add_file_headers(request, media_type, file_size, upload_name)
with responder:
yield responder.write_to_consumer(request)
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 7f263db239..d23fe10b07 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -255,7 +255,9 @@ class FileResponder(Responder):
self.open_file = open_file
def write_to_consumer(self, consumer):
- return FileSender().beginFileTransfer(self.open_file, consumer)
+ return make_deferred_yieldable(
+ FileSender().beginFileTransfer(self.open_file, consumer)
+ )
def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close()
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 01ac71e53e..e086e12213 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -302,7 +302,7 @@ def preserve_fn(f):
def run_in_background(f, *args, **kwargs):
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
- deferred returned by the funtion completes.
+ deferred returned by the function completes.
Useful for wrapping functions that return a deferred which you don't yield
on (for instance because you want to pass it to deferred.gatherResults()).
@@ -320,24 +320,31 @@ def run_in_background(f, *args, **kwargs):
# by synchronous exceptions, so let's turn them into Failures.
return defer.fail()
- if isinstance(res, defer.Deferred) and not res.called:
- # The function will have reset the context before returning, so
- # we need to restore it now.
- LoggingContext.set_current_context(current)
-
- # The original context will be restored when the deferred
- # completes, but there is nothing waiting for it, so it will
- # get leaked into the reactor or some other function which
- # wasn't expecting it. We therefore need to reset the context
- # here.
- #
- # (If this feels asymmetric, consider it this way: we are
- # effectively forking a new thread of execution. We are
- # probably currently within a ``with LoggingContext()`` block,
- # which is supposed to have a single entry and exit point. But
- # by spawning off another deferred, we are effectively
- # adding a new exit point.)
- res.addBoth(_set_context_cb, LoggingContext.sentinel)
+ if not isinstance(res, defer.Deferred):
+ return res
+
+ if res.called and not res.paused:
+ # The function should have maintained the logcontext, so we can
+ # optimise out the messing about
+ return res
+
+ # The function may have reset the context before returning, so
+ # we need to restore it now.
+ ctx = LoggingContext.set_current_context(current)
+
+ # The original context will be restored when the deferred
+ # completes, but there is nothing waiting for it, so it will
+ # get leaked into the reactor or some other function which
+ # wasn't expecting it. We therefore need to reset the context
+ # here.
+ #
+ # (If this feels asymmetric, consider it this way: we are
+ # effectively forking a new thread of execution. We are
+ # probably currently within a ``with LoggingContext()`` block,
+ # which is supposed to have a single entry and exit point. But
+ # by spawning off another deferred, we are effectively
+ # adding a new exit point.)
+ res.addBoth(_set_context_cb, ctx)
return res
@@ -354,9 +361,18 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.)
"""
- if isinstance(deferred, defer.Deferred) and not deferred.called:
- prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
- deferred.addBoth(_set_context_cb, prev_context)
+ if not isinstance(deferred, defer.Deferred):
+ return deferred
+
+ if deferred.called and not deferred.paused:
+ # it looks like this deferred is ready to run any callbacks we give it
+ # immediately. We may as well optimise out the logcontext faffery.
+ return deferred
+
+ # ok, we can't be sure that a yield won't block, so let's reset the
+ # logcontext, and add a callback to the deferred to restore it.
+ prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+ deferred.addBoth(_set_context_cb, prev_context)
return deferred
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index e5a902f734..9181692771 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -17,6 +17,8 @@ from synapse.appservice.scheduler import (
_ServiceQueuer, _TransactionController, _Recoverer
)
from twisted.internet import defer
+
+from synapse.util.logcontext import make_deferred_yieldable
from ..utils import MockClock
from mock import Mock
from tests import unittest
@@ -204,7 +206,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def test_send_single_event_with_queue(self):
d = defer.Deferred()
- self.txn_ctrl.send = Mock(return_value=d)
+ self.txn_ctrl.send = Mock(
+ side_effect=lambda x, y: make_deferred_yieldable(d),
+ )
service = Mock(id=4)
event = Mock(event_id="first")
event2 = Mock(event_id="second")
@@ -235,7 +239,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
srv_2_event2 = Mock(event_id="srv2b")
send_return_list = [srv_1_defer, srv_2_defer]
- self.txn_ctrl.send = Mock(side_effect=lambda x, y: send_return_list.pop(0))
+
+ def do_send(x, y):
+ return make_deferred_yieldable(send_return_list.pop(0))
+ self.txn_ctrl.send = Mock(side_effect=do_send)
# send events for different ASes and make sure they are sent
self.queuer.enqueue(srv1, srv_1_event)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 575374c6a6..9962ce8a5d 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -128,7 +128,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _rotate(10)
yield _assert_counts(1, 1)
- @tests.unittest.DEBUG
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 4850722bc5..ad78d884e0 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -36,24 +36,28 @@ class LoggingContextTestCase(unittest.TestCase):
yield sleep(0)
self._check_test_key("one")
- def _test_preserve_fn(self, function):
+ def _test_run_in_background(self, function):
sentinel_context = LoggingContext.current_context()
callback_completed = [False]
- @defer.inlineCallbacks
- def cb():
+ def test():
context_one.request = "one"
- yield function()
- self._check_test_key("one")
+ d = function()
- callback_completed[0] = True
+ def cb(res):
+ self._check_test_key("one")
+ callback_completed[0] = True
+ return res
+ d.addCallback(cb)
+
+ return d
with LoggingContext() as context_one:
context_one.request = "one"
# fire off function, but don't wait on it.
- logcontext.preserve_fn(cb)()
+ logcontext.run_in_background(test)
self._check_test_key("one")
@@ -80,20 +84,30 @@ class LoggingContextTestCase(unittest.TestCase):
# test is done once d2 finishes
return d2
- def test_preserve_fn_with_blocking_fn(self):
+ def test_run_in_background_with_blocking_fn(self):
@defer.inlineCallbacks
def blocking_function():
yield sleep(0)
- return self._test_preserve_fn(blocking_function)
+ return self._test_run_in_background(blocking_function)
- def test_preserve_fn_with_non_blocking_fn(self):
+ def test_run_in_background_with_non_blocking_fn(self):
@defer.inlineCallbacks
def nonblocking_function():
with logcontext.PreserveLoggingContext():
yield defer.succeed(None)
- return self._test_preserve_fn(nonblocking_function)
+ return self._test_run_in_background(nonblocking_function)
+
+ def test_run_in_background_with_chained_deferred(self):
+ # a function which returns a deferred which looks like it has been
+ # called, but is actually paused
+ def testfunc():
+ return logcontext.make_deferred_yieldable(
+ _chained_deferred_function()
+ )
+
+ return self._test_run_in_background(testfunc)
@defer.inlineCallbacks
def test_make_deferred_yieldable(self):
@@ -119,6 +133,22 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
@defer.inlineCallbacks
+ def test_make_deferred_yieldable_with_chained_deferreds(self):
+ sentinel_context = LoggingContext.current_context()
+
+ with LoggingContext() as context_one:
+ context_one.request = "one"
+
+ d1 = logcontext.make_deferred_yieldable(_chained_deferred_function())
+ # make sure that the context was reset by make_deferred_yieldable
+ self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+ yield d1
+
+ # now it should be restored
+ self._check_test_key("one")
+
+ @defer.inlineCallbacks
def test_make_deferred_yieldable_on_non_deferred(self):
"""Check that make_deferred_yieldable does the right thing when its
argument isn't actually a deferred"""
@@ -132,3 +162,17 @@ class LoggingContextTestCase(unittest.TestCase):
r = yield d1
self.assertEqual(r, "bum")
self._check_test_key("one")
+
+
+# a function which returns a deferred which has been "called", but
+# which had a function which returned another incomplete deferred on
+# its callback list, so won't yet call any other new callbacks.
+def _chained_deferred_function():
+ d = defer.succeed(None)
+
+ def cb(res):
+ d2 = defer.Deferred()
+ reactor.callLater(0, d2.callback, res)
+ return d2
+ d.addCallback(cb)
+ return d
|