summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2018-08-17 14:58:45 +0100
committerGitHub <noreply@github.com>2018-08-17 14:58:45 +0100
commit63260397c6ab6925da0ba3bd24d398001cd8208b (patch)
tree3ef57e48328cd2364afee5810f83d550e0304f8a
parentMerge pull request #3700 from matrix-org/rav/wait_for_producers (diff)
parentchangelog (diff)
downloadsynapse-63260397c6ab6925da0ba3bd24d398001cd8208b.tar.xz
Merge pull request #3701 from matrix-org/rav/use_producer_for_responses
Use a producer to stream back responses
-rw-r--r--changelog.d/3701.bugfix1
-rw-r--r--synapse/http/server.py17
-rw-r--r--tests/rest/client/v1/test_register.py7
-rw-r--r--tests/rest/client/v1/utils.py8
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py23
-rw-r--r--tests/rest/client/v2_alpha/test_register.py26
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py5
-rw-r--r--tests/server.py23
-rw-r--r--tests/test_server.py17
9 files changed, 63 insertions, 64 deletions
diff --git a/changelog.d/3701.bugfix b/changelog.d/3701.bugfix
new file mode 100644
index 0000000000..c22de34537
--- /dev/null
+++ b/changelog.d/3701.bugfix
@@ -0,0 +1 @@
+Avoid timing out requests while we are streaming back the response
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 4e4c02e170..2d5c23e673 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -27,6 +27,7 @@ from twisted.internet import defer
 from twisted.python import failure
 from twisted.web import resource
 from twisted.web.server import NOT_DONE_YET
+from twisted.web.static import NoRangeStaticProducer
 from twisted.web.util import redirectTo
 
 import synapse.events
@@ -40,6 +41,11 @@ from synapse.api.errors import (
 from synapse.util.caches import intern_dict
 from synapse.util.logcontext import preserve_fn
 
+if PY3:
+    from io import BytesIO
+else:
+    from cStringIO import StringIO as BytesIO
+
 logger = logging.getLogger(__name__)
 
 HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
@@ -389,8 +395,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
         return
 
     if pretty_print:
-        json_bytes = (encode_pretty_printed_json(json_object) + "\n"
-                      ).encode("utf-8")
+        json_bytes = encode_pretty_printed_json(json_object) + b"\n"
     else:
         if canonical_json or synapse.events.USE_FROZEN_DICTS:
             # canonicaljson already encodes to bytes
@@ -426,8 +431,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
     if send_cors:
         set_cors_headers(request)
 
-    request.write(json_bytes)
-    finish_request(request)
+    # todo: we can almost certainly avoid this copy and encode the json straight into
+    # the bytesIO, but it would involve faffing around with string->bytes wrappers.
+    bytes_io = BytesIO(json_bytes)
+
+    producer = NoRangeStaticProducer(request, bytes_io)
+    producer.start()
     return NOT_DONE_YET
 
 
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
index 4be88b8a39..6b7ff813d5 100644
--- a/tests/rest/client/v1/test_register.py
+++ b/tests/rest/client/v1/test_register.py
@@ -25,7 +25,7 @@ from synapse.rest.client.v1_only.register import register_servlets
 from synapse.util import Clock
 
 from tests import unittest
-from tests.server import make_request, setup_test_homeserver
+from tests.server import make_request, render, setup_test_homeserver
 
 
 class CreateUserServletTestCase(unittest.TestCase):
@@ -77,10 +77,7 @@ class CreateUserServletTestCase(unittest.TestCase):
         )
 
         request, channel = make_request(b"POST", url, request_data)
-        request.render(res)
-
-        # Advance the clock because it waits
-        self.clock.advance(1)
+        render(request, res, self.clock)
 
         self.assertEquals(channel.result["code"], b"200")
 
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 9f862f9dfa..40dc4ea256 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -23,7 +23,7 @@ from twisted.internet import defer
 from synapse.api.constants import Membership
 
 from tests import unittest
-from tests.server import make_request, wait_until_result
+from tests.server import make_request, render
 
 
 class RestTestCase(unittest.TestCase):
@@ -171,8 +171,7 @@ class RestHelper(object):
         request, channel = make_request(
             "POST", path, json.dumps(content).encode('utf8')
         )
-        request.render(self.resource)
-        wait_until_result(self.hs.get_reactor(), channel)
+        render(request, self.resource, self.hs.get_reactor())
 
         assert channel.result["code"] == b"200", channel.result
         self.auth_user_id = temp_id
@@ -220,8 +219,7 @@ class RestHelper(object):
 
         request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
 
-        request.render(self.resource)
-        wait_until_result(self.hs.get_reactor(), channel)
+        render(request, self.resource, self.hs.get_reactor())
 
         assert int(channel.result["code"]) == expect_code, (
             "Expected: %d, got: %d, resp: %r"
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 8260c130f8..6a886ee3b8 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -24,8 +24,8 @@ from tests import unittest
 from tests.server import (
     ThreadedMemoryReactorClock as MemoryReactorClock,
     make_request,
+    render,
     setup_test_homeserver,
-    wait_until_result,
 )
 
 PATH_PREFIX = "/_matrix/client/v2_alpha"
@@ -76,8 +76,7 @@ class FilterTestCase(unittest.TestCase):
             "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
             self.EXAMPLE_FILTER_JSON,
         )
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertEqual(channel.json_body, {"filter_id": "0"})
@@ -91,8 +90,7 @@ class FilterTestCase(unittest.TestCase):
             "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
             self.EXAMPLE_FILTER_JSON,
         )
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEqual(channel.result["code"], b"403")
         self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@@ -105,8 +103,7 @@ class FilterTestCase(unittest.TestCase):
             "/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
             self.EXAMPLE_FILTER_JSON,
         )
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.hs.is_mine = _is_mine
         self.assertEqual(channel.result["code"], b"403")
@@ -121,8 +118,7 @@ class FilterTestCase(unittest.TestCase):
         request, channel = make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
         )
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
@@ -131,8 +127,7 @@ class FilterTestCase(unittest.TestCase):
         request, channel = make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
         )
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEqual(channel.result["code"], b"400")
         self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
@@ -143,8 +138,7 @@ class FilterTestCase(unittest.TestCase):
         request, channel = make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
         )
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEqual(channel.result["code"], b"400")
 
@@ -153,7 +147,6 @@ class FilterTestCase(unittest.TestCase):
         request, channel = make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
         )
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEqual(channel.result["code"], b"400")
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index b72bd0fb7f..1c128e81f5 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -11,7 +11,7 @@ from synapse.rest.client.v2_alpha.register import register_servlets
 from synapse.util import Clock
 
 from tests import unittest
-from tests.server import make_request, setup_test_homeserver, wait_until_result
+from tests.server import make_request, render, setup_test_homeserver
 
 
 class RegisterRestServletTestCase(unittest.TestCase):
@@ -72,8 +72,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         request, channel = make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
         det_data = {
@@ -89,16 +88,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
         request, channel = make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEquals(channel.result["code"], b"401", channel.result)
 
     def test_POST_bad_password(self):
         request_data = json.dumps({"username": "kermit", "password": 666})
         request, channel = make_request(b"POST", self.url, request_data)
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEquals(channel.result["code"], b"400", channel.result)
         self.assertEquals(channel.json_body["error"], "Invalid password")
@@ -106,8 +103,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
     def test_POST_bad_username(self):
         request_data = json.dumps({"username": 777, "password": "monkey"})
         request, channel = make_request(b"POST", self.url, request_data)
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEquals(channel.result["code"], b"400", channel.result)
         self.assertEquals(channel.json_body["error"], "Invalid username")
@@ -126,8 +122,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.device_handler.check_device_registered = Mock(return_value=device_id)
 
         request, channel = make_request(b"POST", self.url, request_data)
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         det_data = {
             "user_id": user_id,
@@ -149,8 +144,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.registration_handler.register = Mock(return_value=("@user:id", "t"))
 
         request, channel = make_request(b"POST", self.url, request_data)
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Registration has been disabled")
@@ -162,8 +156,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.registration_handler.register = Mock(return_value=(user_id, None))
 
         request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         det_data = {
             "user_id": user_id,
@@ -177,8 +170,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.hs.config.allow_guest_access = False
 
         request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Guest access is disabled")
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 2e1d06c509..9f3d8bd1db 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -23,8 +23,8 @@ from tests import unittest
 from tests.server import (
     ThreadedMemoryReactorClock as MemoryReactorClock,
     make_request,
+    render,
     setup_test_homeserver,
-    wait_until_result,
 )
 
 PATH_PREFIX = "/_matrix/client/v2_alpha"
@@ -69,8 +69,7 @@ class FilterTestCase(unittest.TestCase):
 
     def test_sync_argless(self):
         request, channel = make_request("GET", "/_matrix/client/r0/sync")
-        request.render(self.resource)
-        wait_until_result(self.clock, channel)
+        render(request, self.resource, self.clock)
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertTrue(
diff --git a/tests/server.py b/tests/server.py
index beb24cf032..c63b2c3100 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -24,6 +24,7 @@ class FakeChannel(object):
     """
 
     result = attr.ib(default=attr.Factory(dict))
+    _producer = None
 
     @property
     def json_body(self):
@@ -49,6 +50,15 @@ class FakeChannel(object):
 
         self.result["body"] += content
 
+    def registerProducer(self, producer, streaming):
+        self._producer = producer
+
+    def unregisterProducer(self):
+        if self._producer is None:
+            return
+
+        self._producer = None
+
     def requestDone(self, _self):
         self.result["done"] = True
 
@@ -111,14 +121,19 @@ def make_request(method, path, content=b""):
     return req, channel
 
 
-def wait_until_result(clock, channel, timeout=100):
+def wait_until_result(clock, request, timeout=100):
     """
-    Wait until the channel has a result.
+    Wait until the request is finished.
     """
     clock.run()
     x = 0
 
-    while not channel.result:
+    while not request.finished:
+
+        # If there's a producer, tell it to resume producing so we get content
+        if request._channel._producer:
+            request._channel._producer.resumeProducing()
+
         x += 1
 
         if x > timeout:
@@ -129,7 +144,7 @@ def wait_until_result(clock, channel, timeout=100):
 
 def render(request, resource, clock):
     request.render(resource)
-    wait_until_result(clock, request._channel)
+    wait_until_result(clock, request)
 
 
 class ThreadedMemoryReactorClock(MemoryReactorClock):
diff --git a/tests/test_server.py b/tests/test_server.py
index 895e490406..ef74544e93 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -8,7 +8,7 @@ from synapse.http.server import JsonResource
 from synapse.util import Clock
 
 from tests import unittest
-from tests.server import make_request, setup_test_homeserver
+from tests.server import make_request, render, setup_test_homeserver
 
 
 class JsonResourceTests(unittest.TestCase):
@@ -37,7 +37,7 @@ class JsonResourceTests(unittest.TestCase):
         )
 
         request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
-        request.render(res)
+        render(request, res, self.reactor)
 
         self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
         self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
@@ -55,7 +55,7 @@ class JsonResourceTests(unittest.TestCase):
         res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
 
         request, channel = make_request(b"GET", b"/_matrix/foo")
-        request.render(res)
+        render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b'500')
 
@@ -78,13 +78,8 @@ class JsonResourceTests(unittest.TestCase):
         res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
 
         request, channel = make_request(b"GET", b"/_matrix/foo")
-        request.render(res)
+        render(request, res, self.reactor)
 
-        # No error has been raised yet
-        self.assertTrue("code" not in channel.result)
-
-        # Advance time, now there's an error
-        self.reactor.advance(1)
         self.assertEqual(channel.result["code"], b'500')
 
     def test_callback_synapseerror(self):
@@ -100,7 +95,7 @@ class JsonResourceTests(unittest.TestCase):
         res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
 
         request, channel = make_request(b"GET", b"/_matrix/foo")
-        request.render(res)
+        render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b'403')
         self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -121,7 +116,7 @@ class JsonResourceTests(unittest.TestCase):
         res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
 
         request, channel = make_request(b"GET", b"/_matrix/foobar")
-        request.render(res)
+        render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b'400')
         self.assertEqual(channel.json_body["error"], "Unrecognized request")