diff --git a/changelog.d/9833.bugfix b/changelog.d/9833.bugfix
new file mode 100644
index 0000000000..56f9c9626b
--- /dev/null
+++ b/changelog.d/9833.bugfix
@@ -0,0 +1 @@
+Limit the size of HTTP responses read over federation.
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 1730187ffa..5f40f16e24 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -33,6 +33,7 @@ import treq
from canonicaljson import encode_canonical_json
from netaddr import AddrFormatError, IPAddress, IPSet
from prometheus_client import Counter
+from typing_extensions import Protocol
from zope.interface import implementer, provider
from OpenSSL import SSL
@@ -754,6 +755,16 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
+class ByteWriteable(Protocol):
+ """The type of object which must be passed into read_body_with_max_size.
+
+ Typically this is a file object.
+
+ This is a structural type: any object with a compatible `write` method
+ satisfies it, without having to inherit from this class.
+ """
+
+ def write(self, data: bytes) -> int:
+ """Accept a chunk of bytes.
+
+ Args:
+ data: the bytes to be written.
+
+ Returns:
+ An int. Presumably the number of bytes written, as for file objects;
+ the protocol itself only fixes the type -- TODO confirm.
+ """
+ pass
+
+
class BodyExceededMaxSize(Exception):
"""The maximum allowed size of the HTTP body was exceeded."""
@@ -790,7 +801,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
transport = None # type: Optional[ITCPTransport]
def __init__(
- self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
+ self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int]
):
self.stream = stream
self.deferred = deferred
@@ -830,7 +841,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
def read_body_with_max_size(
- response: IResponse, stream: BinaryIO, max_size: Optional[int]
+ response: IResponse, stream: ByteWriteable, max_size: Optional[int]
) -> defer.Deferred:
"""
Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index d48721a4e2..bb837b7b19 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1,5 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import cgi
+import codecs
import logging
import random
import sys
+import typing
import urllib.parse
-from io import BytesIO
+from io import BytesIO, StringIO
from typing import Callable, Dict, List, Optional, Tuple, Union
import attr
@@ -72,6 +73,9 @@ incoming_responses_counter = Counter(
"synapse_http_matrixfederationclient_responses", "", ["method", "code"]
)
+# A federation response can be rather large (e.g. a big state_ids response is 50M or
+# so), so we need a generous limit here.
+MAX_RESPONSE_SIZE = 100 * 1024 * 1024
MAX_LONG_RETRIES = 10
MAX_SHORT_RETRIES = 3
@@ -167,12 +171,27 @@ async def _handle_json_response(
try:
check_content_type_is_json(response.headers)
- # Use the custom JSON decoder (partially re-implements treq.json_content).
- d = treq.text_content(response, encoding="utf-8")
- d.addCallback(json_decoder.decode)
+ buf = StringIO()
+ d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE)
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
+ def parse(_len: int):
+ return json_decoder.decode(buf.getvalue())
+
+ d.addCallback(parse)
+
body = await make_deferred_yieldable(d)
+ except BodyExceededMaxSize as e:
+ # The response was too big.
+ logger.warning(
+ "{%s} [%s] JSON response exceeded max size %i - %s %s",
+ request.txn_id,
+ request.destination,
+ MAX_RESPONSE_SIZE,
+ request.method,
+ request.uri.decode("ascii"),
+ )
+ raise RequestSendFailed(e, can_retry=False) from e
except ValueError as e:
# The JSON content was invalid.
logger.warning(
@@ -218,6 +237,18 @@ async def _handle_json_response(
return body
+class BinaryIOWrapper:
+ """A wrapper for a TextIO which converts from bytes on the fly.
+
+ Each chunk passed to write() is run through an incremental decoder, so a
+ multi-byte character which is split across two chunks is still decoded
+ correctly once its remaining bytes arrive.
+
+ NOTE(review): nothing in this class calls the decoder with final=True, so the
+ bytes of an incomplete sequence at the very end of the input stay buffered in
+ the decoder: they are neither written to the file nor reported as an error.
+
+ Args:
+ file: the text stream which receives the decoded characters.
+ encoding: name of the codec used to decode the incoming bytes.
+ errors: error-handling scheme passed to the incremental decoder.
+ """
+
+ def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"):
+ # The decoder is stateful: it carries partial sequences between write() calls.
+ self.decoder = codecs.getincrementaldecoder(encoding)(errors)
+ self.file = file
+
+ def write(self, b: Union[bytes, bytearray]) -> int:
+ """Decode `b` and write the resulting text to the wrapped file.
+
+ Args:
+ b: the next chunk of encoded data.
+
+ Returns:
+ The number of bytes consumed (len(b)), not the number of characters
+ written to the underlying file.
+ """
+ self.file.write(self.decoder.decode(b))
+ return len(b)
+
+
class MatrixFederationHttpClient:
"""HTTP client used to talk to other homeservers over the federation
protocol. Send client certificates and signs requests.
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 9e97185507..ed9a884d76 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -26,6 +26,7 @@ from twisted.web.http import HTTPChannel
from synapse.api.errors import RequestSendFailed
from synapse.http.matrixfederationclient import (
+ MAX_RESPONSE_SIZE,
MatrixFederationHttpClient,
MatrixFederationRequest,
)
@@ -560,3 +561,61 @@ class FederationClientTests(HomeserverTestCase):
f = self.failureResultOf(test_d)
self.assertIsInstance(f.value, RequestSendFailed)
+
+ def test_too_big(self):
+ """
+ Test what happens if a huge response is returned from the remote endpoint.
+
+ The body is fed to the client in fixed-size chunks until the request
+ completes. The request is expected to fail with RequestSendFailed exactly
+ when MAX_RESPONSE_SIZE bytes have been sent, and the transport is expected
+ to be marked as disconnecting afterwards.
+ """
+
+ test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8008)
+
+ # complete the connection and wire it up to a fake transport
+ protocol = factory.buildProtocol(None)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ # that should have made it send the request to the transport
+ self.assertRegex(transport.value(), b"^GET /foo/bar")
+ self.assertRegex(transport.value(), b"Host: testserv:8008")
+
+ # Deferred is still without a result
+ self.assertNoResult(test_d)
+
+ # Send it a huge HTTP response
+ # (only the status line and headers here, with no Content-Length; the body
+ # follows in chunks below)
+ protocol.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Server: Fake\r\n"
+ b"Content-Type: application/json\r\n"
+ b"\r\n"
+ )
+
+ self.pump()
+
+ # should still be waiting
+ self.assertNoResult(test_d)
+
+ # Keep feeding body data until the request completes. The assertion inside
+ # the loop stops the test, rather than looping forever, if the client accepts
+ # more than MAX_RESPONSE_SIZE bytes.
+ sent = 0
+ chunk_size = 1024 * 512
+ while not test_d.called:
+ protocol.dataReceived(b"a" * chunk_size)
+ sent += chunk_size
+ self.assertLessEqual(sent, MAX_RESPONSE_SIZE)
+
+ # The request must have completed exactly when the limit was reached; this
+ # relies on MAX_RESPONSE_SIZE being a whole multiple of chunk_size.
+ self.assertEqual(sent, MAX_RESPONSE_SIZE)
+
+ f = self.failureResultOf(test_d)
+ self.assertIsInstance(f.value, RequestSendFailed)
+
+ # The transport should have been told to close the connection.
+ self.assertTrue(transport.disconnecting)
|