summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-19 08:07:57 -0400
committerGitHub <noreply@github.com>2020-08-19 08:07:57 -0400
commitf594e434c35ab99bc71216cbb06082aa2b975980 (patch)
tree0e074197105d52cffd5879964bc413d2587522f6
parentUpdated docs: Added note about missing 308 redirect support. (#8120) (diff)
downloadsynapse-f594e434c35ab99bc71216cbb06082aa2b975980.tar.xz
Switch the JSON byte producer from a pull to a push producer. (#8116)
-rw-r--r--changelog.d/8116.feature1
-rw-r--r--synapse/http/server.py75
-rw-r--r--tests/rest/client/v1/test_login.py16
-rw-r--r--tests/rest/client/v2_alpha/test_register.py4
-rw-r--r--tests/storage/test_cleanup_extrems.py3
5 files changed, 53 insertions, 46 deletions
diff --git a/changelog.d/8116.feature b/changelog.d/8116.feature
new file mode 100644
index 0000000000..b1eaf1e78a
--- /dev/null
+++ b/changelog.d/8116.feature
@@ -0,0 +1 @@
+Iteratively encode JSON to avoid blocking the reactor.
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 37fdf14405..8d791bd2ca 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -500,7 +500,7 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
     pass
 
 
-@implementer(interfaces.IPullProducer)
+@implementer(interfaces.IPushProducer)
 class _ByteProducer:
     """
     Iteratively write bytes to the request.
@@ -515,52 +515,64 @@ class _ByteProducer:
     ):
         self._request = request
         self._iterator = iterator
+        self._paused = False
 
-    def start(self) -> None:
-        self._request.registerProducer(self, False)
+        # Register the producer and start producing data.
+        self._request.registerProducer(self, True)
+        self.resumeProducing()
 
     def _send_data(self, data: List[bytes]) -> None:
         """
-        Send a list of strings as a response to the request.
+        Send a list of bytes as a chunk of a response.
         """
         if not data:
             return
         self._request.write(b"".join(data))
 
+    def pauseProducing(self) -> None:
+        self._paused = True
+
     def resumeProducing(self) -> None:
         # We've stopped producing in the meantime (note that this might be
         # re-entrant after calling write).
         if not self._request:
             return
 
-        # Get the next chunk and write it to the request.
-        #
-        # The output of the JSON encoder is coalesced until min_chunk_size is
-        # reached. (This is because JSON encoders produce a very small output
-        # per iteration.)
-        #
-        # Note that buffer stores a list of bytes (instead of appending to
-        # bytes) to hopefully avoid many allocations.
-        buffer = []
-        buffered_bytes = 0
-        while buffered_bytes < self.min_chunk_size:
-            try:
-                data = next(self._iterator)
-                buffer.append(data)
-                buffered_bytes += len(data)
-            except StopIteration:
-                # The entire JSON object has been serialized, write any
-                # remaining data, finalize the producer and the request, and
-                # clean-up any references.
-                self._send_data(buffer)
-                self._request.unregisterProducer()
-                self._request.finish()
-                self.stopProducing()
-                return
-
-        self._send_data(buffer)
+        self._paused = False
+
+        # Write until there's backpressure telling us to stop.
+        while not self._paused:
+            # Get the next chunk and write it to the request.
+            #
+            # The output of the JSON encoder is buffered and coalesced until
+            # min_chunk_size is reached. This is because JSON encoders produce
+            # very small output per iteration and the Request object converts
+            # each call to write() to a separate chunk. Without this there would
+            # be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n").
+            #
+            # Note that buffer stores a list of bytes (instead of appending to
+            # bytes) to hopefully avoid many allocations.
+            buffer = []
+            buffered_bytes = 0
+            while buffered_bytes < self.min_chunk_size:
+                try:
+                    data = next(self._iterator)
+                    buffer.append(data)
+                    buffered_bytes += len(data)
+                except StopIteration:
+                    # The entire JSON object has been serialized, write any
+                    # remaining data, finalize the producer and the request, and
+                    # clean-up any references.
+                    self._send_data(buffer)
+                    self._request.unregisterProducer()
+                    self._request.finish()
+                    self.stopProducing()
+                    return
+
+            self._send_data(buffer)
 
     def stopProducing(self) -> None:
+        # Clear a circular reference.
         self._request = None
 
 
@@ -620,8 +632,7 @@ def respond_with_json(
     if send_cors:
         set_cors_headers(request)
 
-    producer = _ByteProducer(request, encoder(json_object))
-    producer.start()
+    _ByteProducer(request, encoder(json_object))
     return NOT_DONE_YET
 
 
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index db52725cfe..2668662c9e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
                 "password": "monkey",
             }
-            request_data = json.dumps(params)
-            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            request, channel = self.make_request(b"POST", LOGIN_URL, params)
             self.render(request)
 
             if i == 5:
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # than 1min.
         self.assertTrue(retry_after_ms < 6000)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         params = {
             "type": "m.login.password",
             "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
             "password": "monkey",
         }
-        request_data = json.dumps(params)
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
 
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit"},
                 "password": "monkey",
             }
-            request_data = json.dumps(params)
-            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            request, channel = self.make_request(b"POST", LOGIN_URL, params)
             self.render(request)
 
             if i == 5:
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "monkey",
         }
-        request_data = json.dumps(params)
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
 
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit"},
                 "password": "notamonkey",
             }
-            request_data = json.dumps(params)
-            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            request, channel = self.make_request(b"POST", LOGIN_URL, params)
             self.render(request)
 
             if i == 5:
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # than 1min.
         self.assertTrue(retry_after_ms < 6000)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         params = {
             "type": "m.login.password",
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "notamonkey",
         }
-        request_data = json.dumps(params)
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
 
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 53a43038f0..2fc3a60fc5 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -160,7 +160,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             else:
                 self.assertEquals(channel.result["code"], b"200", channel.result)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
         self.render(request)
@@ -186,7 +186,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             else:
                 self.assertEquals(channel.result["code"], b"200", channel.result)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
         self.render(request)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 8e9a650f9f..43639ca286 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
             "3"
         ] = 300000
+
         self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
         # All entries within time frame
         self.assertEqual(
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
             3,
         )
         # Oldest room to expire
-        self.pump(1)
+        self.pump(1.01)
         self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
         self.assertEqual(
             len(