diff --git a/changelog.d/9421.bugfix b/changelog.d/9421.bugfix
new file mode 100644
index 0000000000..b73ed5664c
--- /dev/null
+++ b/changelog.d/9421.bugfix
@@ -0,0 +1 @@
+Reduce the amount of memory used when generating the URL preview of a file that is larger than the `max_spider_size`.
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 73b414ccff..e54d9bd213 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -56,7 +56,7 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -408,6 +408,9 @@ class SimpleHttpClient:
agent=self.agent,
data=body_producer,
headers=headers,
+ # Avoid buffering the body in treq since we do not reuse
+ # response bodies.
+ unbuffered=True,
**self._extra_treq_args,
) # type: defer.Deferred
@@ -702,18 +705,6 @@ class SimpleHttpClient:
resp_headers = dict(response.headers.getAllRawHeaders())
- if (
- b"Content-Length" in resp_headers
- and max_size
- and int(resp_headers[b"Content-Length"][0]) > max_size
- ):
- logger.warning("Requested URL is too large > %r bytes" % (max_size,))
- raise SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (max_size,),
- Codes.TOO_LARGE,
- )
-
if response.code > 299:
logger.warning("Got %d when downloading %s" % (response.code, url))
raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
@@ -780,7 +771,9 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
# in the meantime.
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(BodyExceededMaxSize())
- self.transport.loseConnection()
+ # Close the connection (forcefully) since all the data will get
+ # discarded anyway.
+ self.transport.abortConnection()
def connectionLost(self, reason: Failure) -> None:
# If the maximum size was already exceeded, there's nothing to do.
@@ -814,6 +807,11 @@ def read_body_with_max_size(
Returns:
A Deferred which resolves to the length of the read body.
"""
+ # If the Content-Length header gives a size larger than the maximum allowed
+ # size, do not bother downloading the body.
+ if max_size is not None and response.length != UNKNOWN_LENGTH:
+ if response.length > max_size:
+ return defer.fail(BodyExceededMaxSize())
d = defer.Deferred()
response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index f17c122e93..2d9b733be0 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -18,6 +18,7 @@ from mock import Mock
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
+from twisted.web.iweb import UNKNOWN_LENGTH
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
@@ -27,12 +28,12 @@ from tests.unittest import TestCase
class ReadBodyWithMaxSizeTests(TestCase):
def setUp(self):
"""Start reading the body, returns the response, result and proto"""
- self.response = Mock()
+ response = Mock(length=UNKNOWN_LENGTH)
self.result = BytesIO()
- self.deferred = read_body_with_max_size(self.response, self.result, 6)
+ self.deferred = read_body_with_max_size(response, self.result, 6)
# Fish the protocol out of the response.
- self.protocol = self.response.deliverBody.call_args[0][0]
+ self.protocol = response.deliverBody.call_args[0][0]
self.protocol.transport = Mock()
def _cleanup_error(self):
@@ -88,7 +89,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self.protocol.dataReceived(b"1234567890")
self.assertIsInstance(self.deferred.result, Failure)
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
- self.protocol.transport.loseConnection.assert_called_once()
+ self.protocol.transport.abortConnection.assert_called_once()
# More data might have come in.
self.protocol.dataReceived(b"1234567890")
|