summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8013.feature1
-rw-r--r--synapse/http/server.py97
-rw-r--r--synapse/python_dependencies.py2
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py6
-rw-r--r--tests/test_server.py1
5 files changed, 94 insertions, 13 deletions
diff --git a/changelog.d/8013.feature b/changelog.d/8013.feature
new file mode 100644
index 0000000000..b1eaf1e78a
--- /dev/null
+++ b/changelog.d/8013.feature
@@ -0,0 +1 @@
+Iteratively encode JSON to avoid blocking the reactor.
diff --git a/synapse/http/server.py b/synapse/http/server.py
index ffe6cfa09e..37fdf14405 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,12 +22,13 @@ import types
 import urllib
 from http import HTTPStatus
 from io import BytesIO
-from typing import Any, Callable, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
 
 import jinja2
-from canonicaljson import encode_canonical_json, encode_pretty_printed_json
+from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
+from zope.interface import implementer
 
-from twisted.internet import defer
+from twisted.internet import defer, interfaces
 from twisted.python import failure
 from twisted.web import resource
 from twisted.web.server import NOT_DONE_YET, Request
@@ -499,6 +500,78 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
     pass
 
 
+@implementer(interfaces.IPullProducer)
+class _ByteProducer:
+    """
+    Iteratively write bytes to the request.
+    """
+
+    # The minimum number of bytes for each chunk. Note that the last chunk will
+    # usually be smaller than this.
+    min_chunk_size = 1024
+
+    def __init__(
+        self, request: Request, iterator: Iterator[bytes],
+    ):
+        self._request = request
+        self._iterator = iterator
+
+    def start(self) -> None:
+        self._request.registerProducer(self, False)
+
+    def _send_data(self, data: List[bytes]) -> None:
+        """
+        Send a list of strings as a response to the request.
+        """
+        if not data:
+            return
+        self._request.write(b"".join(data))
+
+    def resumeProducing(self) -> None:
+        # We've stopped producing in the meantime (note that this might be
+        # re-entrant after calling write).
+        if not self._request:
+            return
+
+        # Get the next chunk and write it to the request.
+        #
+        # The output of the JSON encoder is coalesced until min_chunk_size is
+        # reached. (This is because JSON encoders produce a very small output
+        # per iteration.)
+        #
+        # Note that buffer stores a list of bytes (instead of appending to
+        # bytes) to hopefully avoid many allocations.
+        buffer = []
+        buffered_bytes = 0
+        while buffered_bytes < self.min_chunk_size:
+            try:
+                data = next(self._iterator)
+                buffer.append(data)
+                buffered_bytes += len(data)
+            except StopIteration:
+                # The entire JSON object has been serialized, write any
+                # remaining data, finalize the producer and the request, and
+                # clean-up any references.
+                self._send_data(buffer)
+                self._request.unregisterProducer()
+                self._request.finish()
+                self.stopProducing()
+                return
+
+        self._send_data(buffer)
+
+    def stopProducing(self) -> None:
+        self._request = None
+
+
+def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
+    """
+    Encode an object into JSON. Returns an iterator of bytes.
+    """
+    for chunk in json_encoder.iterencode(json_object):
+        yield chunk.encode("utf-8")
+
+
 def respond_with_json(
     request: Request,
     code: int,
@@ -533,15 +606,23 @@ def respond_with_json(
         return None
 
     if pretty_print:
-        json_bytes = encode_pretty_printed_json(json_object) + b"\n"
+        encoder = iterencode_pretty_printed_json
     else:
         if canonical_json or synapse.events.USE_FROZEN_DICTS:
-            # canonicaljson already encodes to bytes
-            json_bytes = encode_canonical_json(json_object)
+            encoder = iterencode_canonical_json
         else:
-            json_bytes = json_encoder.encode(json_object).encode("utf-8")
+            encoder = _encode_json_bytes
+
+    request.setResponseCode(code)
+    request.setHeader(b"Content-Type", b"application/json")
+    request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
 
-    return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors)
+    if send_cors:
+        set_cors_headers(request)
+
+    producer = _ByteProducer(request, encoder(json_object))
+    producer.start()
+    return NOT_DONE_YET
 
 
 def respond_with_json_bytes(
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 3250d41dde..dd77a44b8d 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -43,7 +43,7 @@ REQUIREMENTS = [
     "jsonschema>=2.5.1",
     "frozendict>=1",
     "unpaddedbase64>=1.1.0",
-    "canonicaljson>=1.2.0",
+    "canonicaljson>=1.3.0",
     # we use the type definitions added in signedjson 1.1.
     "signedjson>=1.1.0",
     "pynacl>=1.2.1",
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 9b3f85b306..e266204f95 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -15,12 +15,12 @@
 import logging
 from typing import Dict, Set
 
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import json
 from signedjson.sign import sign_json
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
+from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
 
 logger = logging.getLogger(__name__)
@@ -223,4 +223,4 @@ class RemoteKey(DirectServeJsonResource):
 
             results = {"server_keys": signed_keys}
 
-            respond_with_json_bytes(request, 200, encode_canonical_json(results))
+            respond_with_json(request, 200, results, canonical_json=True)
diff --git a/tests/test_server.py b/tests/test_server.py
index d628070e48..655c918a15 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -178,7 +178,6 @@ class JsonResourceTests(unittest.TestCase):
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertNotIn("body", channel.result)
-        self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"])
 
 
 class OptionsResourceTests(unittest.TestCase):