summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9421.bugfix1
-rw-r--r--synapse/http/client.py26
-rw-r--r--tests/http/test_client.py9
3 files changed, 18 insertions, 18 deletions
diff --git a/changelog.d/9421.bugfix b/changelog.d/9421.bugfix
new file mode 100644
index 0000000000..b73ed5664c
--- /dev/null
+++ b/changelog.d/9421.bugfix
@@ -0,0 +1 @@
+Reduce the amount of memory used when generating the URL preview of a file that is larger than the `max_spider_size`.
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 73b414ccff..e54d9bd213 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -56,7 +56,7 @@ from twisted.web.client import (
 )
 from twisted.web.http import PotentialDataLoss
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
 
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -408,6 +408,9 @@ class SimpleHttpClient:
                     agent=self.agent,
                     data=body_producer,
                     headers=headers,
+                    # Avoid buffering the body in treq since we do not reuse
+                    # response bodies.
+                    unbuffered=True,
                     **self._extra_treq_args,
                 )  # type: defer.Deferred
 
@@ -702,18 +705,6 @@ class SimpleHttpClient:
 
         resp_headers = dict(response.headers.getAllRawHeaders())
 
-        if (
-            b"Content-Length" in resp_headers
-            and max_size
-            and int(resp_headers[b"Content-Length"][0]) > max_size
-        ):
-            logger.warning("Requested URL is too large > %r bytes" % (max_size,))
-            raise SynapseError(
-                502,
-                "Requested file is too large > %r bytes" % (max_size,),
-                Codes.TOO_LARGE,
-            )
-
         if response.code > 299:
             logger.warning("Got %d when downloading %s" % (response.code, url))
             raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
@@ -780,7 +771,9 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         # in the meantime.
         if self.max_size is not None and self.length >= self.max_size:
             self.deferred.errback(BodyExceededMaxSize())
-            self.transport.loseConnection()
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            self.transport.abortConnection()
 
     def connectionLost(self, reason: Failure) -> None:
         # If the maximum size was already exceeded, there's nothing to do.
@@ -814,6 +807,11 @@ def read_body_with_max_size(
     Returns:
         A Deferred which resolves to the length of the read body.
     """
+    # If the Content-Length header gives a size larger than the maximum allowed
+    # size, do not bother downloading the body.
+    if max_size is not None and response.length != UNKNOWN_LENGTH:
+        if response.length > max_size:
+            return defer.fail(BodyExceededMaxSize())
 
     d = defer.Deferred()
     response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index f17c122e93..2d9b733be0 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -18,6 +18,7 @@ from mock import Mock
 
 from twisted.python.failure import Failure
 from twisted.web.client import ResponseDone
+from twisted.web.iweb import UNKNOWN_LENGTH
 
 from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
 
@@ -27,12 +28,12 @@ from tests.unittest import TestCase
 class ReadBodyWithMaxSizeTests(TestCase):
     def setUp(self):
         """Start reading the body, returns the response, result and proto"""
-        self.response = Mock()
+        response = Mock(length=UNKNOWN_LENGTH)
         self.result = BytesIO()
-        self.deferred = read_body_with_max_size(self.response, self.result, 6)
+        self.deferred = read_body_with_max_size(response, self.result, 6)
 
         # Fish the protocol out of the response.
-        self.protocol = self.response.deliverBody.call_args[0][0]
+        self.protocol = response.deliverBody.call_args[0][0]
         self.protocol.transport = Mock()
 
     def _cleanup_error(self):
@@ -88,7 +89,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
         self.protocol.dataReceived(b"1234567890")
         self.assertIsInstance(self.deferred.result, Failure)
         self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
-        self.protocol.transport.loseConnection.assert_called_once()
+        self.protocol.transport.abortConnection.assert_called_once()
 
         # More data might have come in.
         self.protocol.dataReceived(b"1234567890")