diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index a98091d711..721917f957 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -37,18 +37,155 @@ from synapse.http.client import (
BlocklistingAgentWrapper,
BlocklistingReactorWrapper,
BodyExceededMaxSize,
+ MultipartResponse,
_DiscardBodyWithMaxSizeProtocol,
+ _MultipartParserProtocol,
read_body_with_max_size,
+ read_multipart_response,
)
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
+class ReadMultipartResponseTests(TestCase):
+ data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
+ data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+
+ redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+
+ def _build_multipart_response(
+ self, response_length: Union[int, str], max_length: int
+ ) -> Tuple[
+ BytesIO,
+ "Deferred[MultipartResponse]",
+ _MultipartParserProtocol,
+ ]:
+ """Start reading the body, returns the response, result and proto"""
+ response = Mock(length=response_length)
+ result = BytesIO()
+ boundary = "6067d4698f8d40a0a794ea7d7379d53a"
+ deferred = read_multipart_response(response, result, boundary, max_length)
+
+ # Fish the protocol out of the response.
+ protocol = response.deliverBody.call_args[0][0]
+ protocol.transport = Mock()
+
+ return result, deferred, protocol
+
+ def _assert_error(
+ self,
+ deferred: "Deferred[MultipartResponse]",
+ protocol: _MultipartParserProtocol,
+ ) -> None:
+ """Ensure that the expected error is received."""
+ assert isinstance(deferred.result, Failure)
+ self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
+ assert protocol.transport is not None
+ # type-ignore: presumably abortConnection has been replaced with a Mock.
+ protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined]
+
+ def _cleanup_error(self, deferred: "Deferred[MultipartResponse]") -> None:
+ """Ensure that the error in the Deferred is handled gracefully."""
+ called = [False]
+
+ def errback(f: Failure) -> None:
+ called[0] = True
+
+ deferred.addErrback(errback)
+ self.assertTrue(called[0])
+
+ def test_parse_file(self) -> None:
+ """
+ Check that a multipart response containing a file is properly parsed
+ into the json/file parts, and the json and file are properly captured
+ """
+ result, deferred, protocol = self._build_multipart_response(249, 250)
+
+ # Start sending data.
+ protocol.dataReceived(self.data1)
+ protocol.dataReceived(self.data2)
+ # Close the connection.
+ protocol.connectionLost(Failure(ResponseDone()))
+
+ multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]
+
+ self.assertEqual(multipart_response.json, b"{}")
+ self.assertEqual(result.getvalue(), b"file_to_stream")
+ self.assertEqual(multipart_response.length, len(b"file_to_stream"))
+ self.assertEqual(multipart_response.content_type, b"text/plain")
+ self.assertEqual(
+ multipart_response.disposition, b"inline; filename=test_upload"
+ )
+
+ def test_parse_redirect(self) -> None:
+ """
+ check that a multipart response containing a redirect is properly parsed and redirect url is
+ returned
+ """
+ result, deferred, protocol = self._build_multipart_response(249, 250)
+
+ # Start sending data.
+ protocol.dataReceived(self.redirect_data)
+ # Close the connection.
+ protocol.connectionLost(Failure(ResponseDone()))
+
+ multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]
+
+ self.assertEqual(multipart_response.json, b"{}")
+ self.assertEqual(result.getvalue(), b"")
+ self.assertEqual(
+ multipart_response.url, b"https://cdn.example.org/ab/c1/2345.txt"
+ )
+
+ def test_too_large(self) -> None:
+ """A response which is too large raises an exception."""
+ result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
+
+ # Start sending data.
+ protocol.dataReceived(self.data1)
+
+ self.assertEqual(result.getvalue(), b"file_")
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
+
+ def test_additional_data(self) -> None:
+ """A connection can receive data after being closed."""
+ result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
+
+ # Start sending data.
+ protocol.dataReceived(self.data1)
+ self._assert_error(deferred, protocol)
+
+ # More data might have come in.
+ protocol.dataReceived(self.data2)
+
+ self.assertEqual(result.getvalue(), b"file_")
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
+
+ def test_content_length(self) -> None:
+ """The body shouldn't be read (at all) if the Content-Length header is too large."""
+ result, deferred, protocol = self._build_multipart_response(250, 1)
+
+ # Deferred shouldn't be called yet.
+ self.assertFalse(deferred.called)
+
+ # Start sending data.
+ protocol.dataReceived(self.data1)
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
+
+ # The data is never consumed.
+ self.assertEqual(result.getvalue(), b"")
+
+
class ReadBodyWithMaxSizeTests(TestCase):
- def _build_response(
- self, length: Union[int, str] = UNKNOWN_LENGTH
- ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
+ def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[
+ BytesIO,
+ "Deferred[int]",
+ _DiscardBodyWithMaxSizeProtocol,
+ ]:
"""Start reading the body, returns the response, result and proto"""
response = Mock(length=length)
result = BytesIO()
|