summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/5593.bugfix1
-rw-r--r--synapse/http/server.py26
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py1
-rw-r--r--synapse/util/logcontext.py9
-rw-r--r--tests/rest/media/v1/test_url_preview.py12
-rw-r--r--tests/util/test_logcontext.py33
6 files changed, 58 insertions, 24 deletions
diff --git a/changelog.d/5593.bugfix b/changelog.d/5593.bugfix
new file mode 100644
index 0000000000..e981589ac3
--- /dev/null
+++ b/changelog.d/5593.bugfix
@@ -0,0 +1 @@
+Fix regression in 1.1rc1 where OPTIONS requests to the media repo would fail.
diff --git a/synapse/http/server.py b/synapse/http/server.py
index f067c163c1..d993161a3e 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -65,8 +65,8 @@ def wrap_json_request_handler(h):
     The handler method must have a signature of "handle_foo(self, request)",
     where "request" must be a SynapseRequest.
 
-    The handler must return a deferred. If the deferred succeeds we assume that
-    a response has been sent. If the deferred fails with a SynapseError we use
+    The handler must return a deferred or a coroutine. If the deferred succeeds
+    we assume that a response has been sent. If the deferred fails with a SynapseError we use
     it to send a JSON response with the appropriate HTTP reponse code. If the
     deferred fails with any other type of error we send a 500 reponse.
     """
@@ -353,16 +353,22 @@ class DirectServeResource(resource.Resource):
         """
         Render the request, using an asynchronous render handler if it exists.
         """
-        render_callback_name = "_async_render_" + request.method.decode("ascii")
+        async_render_callback_name = "_async_render_" + request.method.decode("ascii")
 
-        if hasattr(self, render_callback_name):
-            # Call the handler
-            callback = getattr(self, render_callback_name)
-            defer.ensureDeferred(callback(request))
+        # Try and get the async renderer
+        callback = getattr(self, async_render_callback_name, None)
 
-            return NOT_DONE_YET
-        else:
-            super().render(request)
+        # No async renderer for this request method.
+        if not callback:
+            return super().render(request)
+
+        resp = callback(request)
+
+        # If it's a coroutine, turn it into a Deferred
+        if isinstance(resp, types.CoroutineType):
+            defer.ensureDeferred(resp)
+
+        return NOT_DONE_YET
 
 
 def _options_handler(request):
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 0337b64dc2..053346fb86 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -95,6 +95,7 @@ class PreviewUrlResource(DirectServeResource):
         )
 
     def render_OPTIONS(self, request):
+        request.setHeader(b"Allow", b"OPTIONS, GET")
         return respond_with_json(request, 200, {}, send_cors=True)
 
     @wrap_json_request_handler
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 6b0d2deea0..9e1b537804 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -24,6 +24,7 @@ See doc/log_contexts.rst for details on how this works.
 
 import logging
 import threading
+import types
 
 from twisted.internet import defer, threads
 
@@ -528,8 +529,9 @@ def run_in_background(f, *args, **kwargs):
     return from the function, and that the sentinel context is set once the
     deferred returned by the function completes.
 
-    Useful for wrapping functions that return a deferred which you don't yield
-    on (for instance because you want to pass it to deferred.gatherResults()).
+    Useful for wrapping functions that return a deferred or coroutine, which you don't
+    yield or await on (for instance because you want to pass it to
+    deferred.gatherResults()).
 
     Note that if you completely discard the result, you should make sure that
     `f` doesn't raise any deferred exceptions, otherwise a scary-looking
@@ -544,6 +546,9 @@ def run_in_background(f, *args, **kwargs):
         # by synchronous exceptions, so let's turn them into Failures.
         return defer.fail()
 
+    if isinstance(res, types.CoroutineType):
+        res = defer.ensureDeferred(res)
+
     if not isinstance(res, defer.Deferred):
         return res
 
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 8fe5961866..976652aee8 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -460,3 +460,15 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                 "error": "DNS resolution failure during URL preview generation",
             },
         )
+
+    def test_OPTIONS(self):
+        """
+        OPTIONS returns the OPTIONS.
+        """
+        request, channel = self.make_request(
+            "OPTIONS", "url_preview?url=http://example.com", shorthand=False
+        )
+        request.render(self.preview_url)
+        self.pump()
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body, {})
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8adaee3c8d..8d69fbf111 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -39,24 +39,17 @@ class LoggingContextTestCase(unittest.TestCase):
 
         callback_completed = [False]
 
-        def test():
+        with LoggingContext() as context_one:
             context_one.request = "one"
-            d = function()
+
+            # fire off function, but don't wait on it.
+            d2 = logcontext.run_in_background(function)
 
             def cb(res):
-                self._check_test_key("one")
                 callback_completed[0] = True
                 return res
 
-            d.addCallback(cb)
-
-            return d
-
-        with LoggingContext() as context_one:
-            context_one.request = "one"
-
-            # fire off function, but don't wait on it.
-            logcontext.run_in_background(test)
+            d2.addCallback(cb)
 
             self._check_test_key("one")
 
@@ -105,6 +98,22 @@ class LoggingContextTestCase(unittest.TestCase):
 
         return self._test_run_in_background(testfunc)
 
+    def test_run_in_background_with_coroutine(self):
+        async def testfunc():
+            self._check_test_key("one")
+            d = Clock(reactor).sleep(0)
+            self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+            await d
+            self._check_test_key("one")
+
+        return self._test_run_in_background(testfunc)
+
+    def test_run_in_background_with_nonblocking_coroutine(self):
+        async def testfunc():
+            self._check_test_key("one")
+
+        return self._test_run_in_background(testfunc)
+
     @defer.inlineCallbacks
     def test_make_deferred_yieldable(self):
         # a function which retuns an incomplete deferred, but doesn't follow