diff options
Diffstat (limited to '')
-rw-r--r-- | changelog.d/8013.feature | 1 | ||||
-rw-r--r-- | synapse/http/server.py | 97 | ||||
-rw-r--r-- | synapse/python_dependencies.py | 2 | ||||
-rw-r--r-- | synapse/rest/key/v2/remote_key_resource.py | 6 | ||||
-rw-r--r-- | tests/test_server.py | 1 |
5 files changed, 94 insertions, 13 deletions
diff --git a/changelog.d/8013.feature b/changelog.d/8013.feature new file mode 100644 index 0000000000..b1eaf1e78a --- /dev/null +++ b/changelog.d/8013.feature @@ -0,0 +1 @@ +Iteratively encode JSON to avoid blocking the reactor. diff --git a/synapse/http/server.py b/synapse/http/server.py index ffe6cfa09e..37fdf14405 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -22,12 +22,13 @@ import types import urllib from http import HTTPStatus from io import BytesIO -from typing import Any, Callable, Dict, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Tuple, Union import jinja2 -from canonicaljson import encode_canonical_json, encode_pretty_printed_json +from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json +from zope.interface import implementer -from twisted.internet import defer +from twisted.internet import defer, interfaces from twisted.python import failure from twisted.web import resource from twisted.web.server import NOT_DONE_YET, Request @@ -499,6 +500,78 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect): pass +@implementer(interfaces.IPullProducer) +class _ByteProducer: + """ + Iteratively write bytes to the request. + """ + + # The minimum number of bytes for each chunk. Note that the last chunk will + # usually be smaller than this. + min_chunk_size = 1024 + + def __init__( + self, request: Request, iterator: Iterator[bytes], + ): + self._request = request + self._iterator = iterator + + def start(self) -> None: + self._request.registerProducer(self, False) + + def _send_data(self, data: List[bytes]) -> None: + """ + Send a list of strings as a response to the request. + """ + if not data: + return + self._request.write(b"".join(data)) + + def resumeProducing(self) -> None: + # We've stopped producing in the meantime (note that this might be + # re-entrant after calling write). + if not self._request: + return + + # Get the next chunk and write it to the request. + # + # The output of the JSON encoder is coalesced until min_chunk_size is + # reached. (This is because JSON encoders produce a very small output + # per iteration.) + # + # Note that buffer stores a list of bytes (instead of appending to + # bytes) to hopefully avoid many allocations. + buffer = [] + buffered_bytes = 0 + while buffered_bytes < self.min_chunk_size: + try: + data = next(self._iterator) + buffer.append(data) + buffered_bytes += len(data) + except StopIteration: + # The entire JSON object has been serialized, write any + # remaining data, finalize the producer and the request, and + # clean-up any references. + self._send_data(buffer) + self._request.unregisterProducer() + self._request.finish() + self.stopProducing() + return + + self._send_data(buffer) + + def stopProducing(self) -> None: + self._request = None + + +def _encode_json_bytes(json_object: Any) -> Iterator[bytes]: + """ + Encode an object into JSON. Returns an iterator of bytes. + """ + for chunk in json_encoder.iterencode(json_object): + yield chunk.encode("utf-8") + + def respond_with_json( request: Request, code: int, @@ -533,15 +606,23 @@ def respond_with_json( return None if pretty_print: - json_bytes = encode_pretty_printed_json(json_object) + b"\n" + encoder = iterencode_pretty_printed_json else: if canonical_json or synapse.events.USE_FROZEN_DICTS: - # canonicaljson already encodes to bytes - json_bytes = encode_canonical_json(json_object) + encoder = iterencode_canonical_json else: - json_bytes = json_encoder.encode(json_object).encode("utf-8") + encoder = _encode_json_bytes + + request.setResponseCode(code) + request.setHeader(b"Content-Type", b"application/json") + request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate") - return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors) + if send_cors: + set_cors_headers(request) + + producer = _ByteProducer(request, encoder(json_object)) + producer.start() + return NOT_DONE_YET def respond_with_json_bytes( diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 3250d41dde..dd77a44b8d 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -43,7 +43,7 @@ REQUIREMENTS = [ "jsonschema>=2.5.1", "frozendict>=1", "unpaddedbase64>=1.1.0", - "canonicaljson>=1.2.0", + "canonicaljson>=1.3.0", # we use the type definitions added in signedjson 1.1. "signedjson>=1.1.0", "pynacl>=1.2.1", diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 9b3f85b306..e266204f95 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,12 +15,12 @@ import logging from typing import Dict, Set -from canonicaljson import encode_canonical_json, json +from canonicaljson import json from signedjson.sign import sign_json from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes +from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request logger = logging.getLogger(__name__) @@ -223,4 +223,4 @@ class RemoteKey(DirectServeJsonResource): results = {"server_keys": signed_keys} - respond_with_json_bytes(request, 200, encode_canonical_json(results)) + respond_with_json(request, 200, results, canonical_json=True) diff --git a/tests/test_server.py b/tests/test_server.py index d628070e48..655c918a15 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -178,7 +178,6 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) - self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"]) class OptionsResourceTests(unittest.TestCase): |