diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index ef10ec0937..cdc36b8d25 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -465,8 +465,9 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
if (
- self.hs.config.federation_domain_whitelist is not None
- and request.destination not in self.hs.config.federation_domain_whitelist
+ self.hs.config.federation.federation_domain_whitelist is not None
+ and request.destination
+ not in self.hs.config.federation.federation_domain_whitelist
):
raise FederationDeniedError(request.destination)
@@ -1186,7 +1187,7 @@ class MatrixFederationHttpClient:
request.method,
request.uri.decode("ascii"),
)
- return (length, headers)
+ return length, headers
def _flatten_response_never_received(e):
diff --git a/synapse/http/server.py b/synapse/http/server.py
index b79fa722e9..1a50305dcf 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,7 +21,6 @@ import types
import urllib
from http import HTTPStatus
from inspect import isawaitable
-from io import BytesIO
from typing import (
Any,
Awaitable,
@@ -37,7 +36,7 @@ from typing import (
)
import jinja2
-from canonicaljson import iterencode_canonical_json
+from canonicaljson import encode_canonical_json
from typing_extensions import Protocol
from zope.interface import implementer
@@ -45,7 +44,7 @@ from twisted.internet import defer, interfaces
from twisted.python import failure
from twisted.web import resource
from twisted.web.server import NOT_DONE_YET, Request
-from twisted.web.static import File, NoRangeStaticProducer
+from twisted.web.static import File
from twisted.web.util import redirectTo
from synapse.api.errors import (
@@ -56,10 +55,11 @@ from synapse.api.errors import (
UnrecognizedRequestError,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.context import preserve_fn
+from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import trace_servlet
from synapse.util import json_encoder
from synapse.util.caches import intern_dict
+from synapse.util.iterutils import chunk_seq
logger = logging.getLogger(__name__)
@@ -320,7 +320,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_response(
self,
- request: Request,
+ request: SynapseRequest,
code: int,
response_object: Any,
):
@@ -620,16 +620,15 @@ class _ByteProducer:
self._request = None
-def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
+def _encode_json_bytes(json_object: Any) -> bytes:
"""
-    Encode an object into JSON. Returns an iterator of bytes.
+    Encode an object into JSON. Returns the encoded JSON as bytes.
"""
- for chunk in json_encoder.iterencode(json_object):
- yield chunk.encode("utf-8")
+ return json_encoder.encode(json_object).encode("utf-8")
def respond_with_json(
- request: Request,
+ request: SynapseRequest,
code: int,
json_object: Any,
send_cors: bool = False,
@@ -659,7 +658,7 @@ def respond_with_json(
return None
if canonical_json:
- encoder = iterencode_canonical_json
+ encoder = encode_canonical_json
else:
encoder = _encode_json_bytes
@@ -670,7 +669,9 @@ def respond_with_json(
if send_cors:
set_cors_headers(request)
- _ByteProducer(request, encoder(json_object))
+ run_in_background(
+ _async_write_json_to_request_in_thread, request, encoder, json_object
+ )
return NOT_DONE_YET
@@ -706,15 +707,56 @@ def respond_with_json_bytes(
if send_cors:
set_cors_headers(request)
- # note that this is zero-copy (the bytesio shares a copy-on-write buffer with
- # the original `bytes`).
- bytes_io = BytesIO(json_bytes)
-
- producer = NoRangeStaticProducer(request, bytes_io)
- producer.start()
+ _write_bytes_to_request(request, json_bytes)
return NOT_DONE_YET
+async def _async_write_json_to_request_in_thread(
+    request: SynapseRequest,
+    json_encoder: Callable[[Any], bytes],
+    json_object: Any,
+) -> None:
+    """Encodes the given JSON object on a thread and then writes the resulting
+    bytes to the request.
+
+    This is done so that encoding large JSON objects doesn't block the reactor
+    thread.
+
+    Note: We don't use JsonEncoder.iterencode here as that falls back to the
+    Python implementation (rather than the C backend), which is *much* more
+    expensive.
+    """
+
+    json_bytes = await defer_to_thread(request.reactor, json_encoder, json_object)
+
+    _write_bytes_to_request(request, json_bytes)
+
+
+def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
+    """Sends the given bytes back on the request via an appropriate producer.
+
+    Note: Prefer this over calling `Request.write` directly, so that large
+    response bodies are handled correctly.
+    """
+
+    # Handing the whole response to the `Request` object in one go (via
+    # `Request.write`) starts the timeout for the *next* request straight away:
+    # if streaming the response back to the client then takes longer than 60s,
+    # the client never gets it.
+    #
+    # Using a Producer avoids this, as the timeout is then only started once
+    # all of the content has been sent over the TCP connection.
+
+    # Split the body up into chunks so that we never write all of the bytes at
+    # once.
+    max_chunk_bytes = 4096
+    body_chunks = chunk_seq(bytes_to_write, max_chunk_bytes)
+
+    # `_ByteProducer` is used in preference to `NoRangeStaticProducer` because
+    # the unit tests can't cope with being given a pull producer.
+    _ByteProducer(request, body_chunks)
+
+
def set_cors_headers(request: Request):
"""Set the CORS headers so that javascript running in a web browsers can
use this API
diff --git a/synapse/http/site.py b/synapse/http/site.py
index dd4c749e16..755ad56637 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,13 +14,14 @@
import contextlib
import logging
import time
-from typing import Optional, Tuple, Union
+from typing import Generator, Optional, Tuple, Union
import attr
from zope.interface import implementer
from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
+from twisted.web.http import HTTPChannel
from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site
@@ -61,10 +62,18 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""
- def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw):
- Request.__init__(self, channel, *args, **kw)
+ def __init__(
+ self,
+ channel: HTTPChannel,
+ site: "SynapseSite",
+ *args,
+ max_request_body_size: int = 1024,
+ **kw,
+ ):
+ super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size
- self.site: SynapseSite = channel.site
+ self.synapse_site = site
+ self.reactor = site.reactor
self._channel = channel # this is used by the tests
self.start_time = 0.0
@@ -97,7 +106,7 @@ class SynapseRequest(Request):
self.get_method(),
self.get_redacted_uri(),
self.clientproto.decode("ascii", errors="replace"),
- self.site.site_tag,
+ self.synapse_site.site_tag,
)
def handleContentChunk(self, data: bytes) -> None:
@@ -216,7 +225,7 @@ class SynapseRequest(Request):
request=ContextRequest(
request_id=request_id,
ip_address=self.getClientIP(),
- site_tag=self.site.site_tag,
+ site_tag=self.synapse_site.site_tag,
# The requester is going to be unknown at this point.
requester=None,
authenticated_entity=None,
@@ -228,7 +237,7 @@ class SynapseRequest(Request):
)
# override the Server header which is set by twisted
- self.setHeader("Server", self.site.server_version_string)
+ self.setHeader("Server", self.synapse_site.server_version_string)
with PreserveLoggingContext(self.logcontext):
# we start the request metrics timer here with an initial stab
@@ -247,7 +256,7 @@ class SynapseRequest(Request):
requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
@contextlib.contextmanager
- def processing(self):
+ def processing(self) -> Generator[None, None, None]:
"""Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is:
@@ -346,10 +355,10 @@ class SynapseRequest(Request):
self.start_time, name=servlet_name, method=self.get_method()
)
- self.site.access_logger.debug(
+ self.synapse_site.access_logger.debug(
"%s - %s - Received request: %s %s",
self.getClientIP(),
- self.site.site_tag,
+ self.synapse_site.site_tag,
self.get_method(),
self.get_redacted_uri(),
)
@@ -388,13 +397,13 @@ class SynapseRequest(Request):
if authenticated_entity:
requester = f"{authenticated_entity}|{requester}"
- self.site.access_logger.log(
+ self.synapse_site.access_logger.log(
log_level,
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
- self.site.site_tag,
+ self.synapse_site.site_tag,
requester,
processing_time,
response_send_time,
@@ -522,7 +531,7 @@ class SynapseSite(Site):
site_tag: str,
config: ListenerConfig,
resource: IResource,
- server_version_string,
+ server_version_string: str,
max_request_body_size: int,
reactor: IReactorTime,
):
@@ -542,6 +551,7 @@ class SynapseSite(Site):
Site.__init__(self, resource, reactor=reactor)
self.site_tag = site_tag
+ self.reactor = reactor
assert config.http_options is not None
proxied = config.http_options.x_forwarded
@@ -550,6 +560,7 @@ class SynapseSite(Site):
def request_factory(channel, queued: bool) -> Request:
return request_class(
channel,
+ self,
max_request_body_size=max_request_body_size,
queued=queued,
)
|