summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/client.py31
1 files changed, 29 insertions, 2 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index a910548f1e..72901e3f95 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -748,7 +748,32 @@ class BodyExceededMaxSize(Exception):
     """The maximum allowed size of the HTTP body was exceeded."""
 
 
+class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
+    """A protocol which immediately errors upon receiving data."""
+
+    def __init__(self, deferred: defer.Deferred):
+        self.deferred = deferred
+
+    def _maybe_fail(self):
+        """
+        Report a max size exceed error and disconnect the first time this is called.
+        """
+        if not self.deferred.called:
+            self.deferred.errback(BodyExceededMaxSize())
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            self.transport.abortConnection()
+
+    def dataReceived(self, data: bytes) -> None:
+        self._maybe_fail()
+
+    def connectionLost(self, reason: Failure) -> None:
+        self._maybe_fail()
+
+
 class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
+    """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
+
     def __init__(
         self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
     ):
@@ -805,13 +830,15 @@ def read_body_with_max_size(
     Returns:
         A Deferred which resolves to the length of the read body.
     """
+    d = defer.Deferred()
+
     # If the Content-Length header gives a size larger than the maximum allowed
     # size, do not bother downloading the body.
     if max_size is not None and response.length != UNKNOWN_LENGTH:
         if response.length > max_size:
-            return defer.fail(BodyExceededMaxSize())
+            response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
+            return d
 
-    d = defer.Deferred()
     response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
     return d