summary refs log tree commit diff
path: root/tests/http/test_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http/test_client.py')
-rw-r--r--tests/http/test_client.py143
1 files changed, 140 insertions, 3 deletions
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index a98091d711..721917f957 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -37,18 +37,155 @@ from synapse.http.client import (
     BlocklistingAgentWrapper,
     BlocklistingReactorWrapper,
     BodyExceededMaxSize,
+    MultipartResponse,
     _DiscardBodyWithMaxSizeProtocol,
+    _MultipartParserProtocol,
     read_body_with_max_size,
+    read_multipart_response,
 )
 
 from tests.server import FakeTransport, get_clock
 from tests.unittest import TestCase
 
 
+class ReadMultipartResponseTests(TestCase):
+    data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
+    data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+
+    redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+
+    def _build_multipart_response(
+        self, response_length: Union[int, str], max_length: int
+    ) -> Tuple[
+        BytesIO,
+        "Deferred[MultipartResponse]",
+        _MultipartParserProtocol,
+    ]:
+        """Start reading the body, returns the response, result and proto"""
+        response = Mock(length=response_length)
+        result = BytesIO()
+        boundary = "6067d4698f8d40a0a794ea7d7379d53a"
+        deferred = read_multipart_response(response, result, boundary, max_length)
+
+        # Fish the protocol out of the response.
+        protocol = response.deliverBody.call_args[0][0]
+        protocol.transport = Mock()
+
+        return result, deferred, protocol
+
+    def _assert_error(
+        self,
+        deferred: "Deferred[MultipartResponse]",
+        protocol: _MultipartParserProtocol,
+    ) -> None:
+        """Ensure that the expected error is received."""
+        assert isinstance(deferred.result, Failure)
+        self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
+        assert protocol.transport is not None
+        # type-ignore: presumably abortConnection has been replaced with a Mock.
+        protocol.transport.abortConnection.assert_called_once()  # type: ignore[attr-defined]
+
+    def _cleanup_error(self, deferred: "Deferred[MultipartResponse]") -> None:
+        """Ensure that the error in the Deferred is handled gracefully."""
+        called = [False]
+
+        def errback(f: Failure) -> None:
+            called[0] = True
+
+        deferred.addErrback(errback)
+        self.assertTrue(called[0])
+
+    def test_parse_file(self) -> None:
+        """
+        Check that a multipart response containing a file is properly parsed
+        into the json/file parts, and the json and file are properly captured
+        """
+        result, deferred, protocol = self._build_multipart_response(249, 250)
+
+        # Start sending data.
+        protocol.dataReceived(self.data1)
+        protocol.dataReceived(self.data2)
+        # Close the connection.
+        protocol.connectionLost(Failure(ResponseDone()))
+
+        multipart_response: MultipartResponse = deferred.result  # type: ignore[assignment]
+
+        self.assertEqual(multipart_response.json, b"{}")
+        self.assertEqual(result.getvalue(), b"file_to_stream")
+        self.assertEqual(multipart_response.length, len(b"file_to_stream"))
+        self.assertEqual(multipart_response.content_type, b"text/plain")
+        self.assertEqual(
+            multipart_response.disposition, b"inline; filename=test_upload"
+        )
+
+    def test_parse_redirect(self) -> None:
+        """
+        check that a multipart response containing a redirect is properly parsed and redirect url is
+        returned
+        """
+        result, deferred, protocol = self._build_multipart_response(249, 250)
+
+        # Start sending data.
+        protocol.dataReceived(self.redirect_data)
+        # Close the connection.
+        protocol.connectionLost(Failure(ResponseDone()))
+
+        multipart_response: MultipartResponse = deferred.result  # type: ignore[assignment]
+
+        self.assertEqual(multipart_response.json, b"{}")
+        self.assertEqual(result.getvalue(), b"")
+        self.assertEqual(
+            multipart_response.url, b"https://cdn.example.org/ab/c1/2345.txt"
+        )
+
+    def test_too_large(self) -> None:
+        """A response which is too large raises an exception."""
+        result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
+
+        # Start sending data.
+        protocol.dataReceived(self.data1)
+
+        self.assertEqual(result.getvalue(), b"file_")
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)
+
+    def test_additional_data(self) -> None:
+        """A connection can receive data after being closed."""
+        result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
+
+        # Start sending data.
+        protocol.dataReceived(self.data1)
+        self._assert_error(deferred, protocol)
+
+        # More data might have come in.
+        protocol.dataReceived(self.data2)
+
+        self.assertEqual(result.getvalue(), b"file_")
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)
+
+    def test_content_length(self) -> None:
+        """The body shouldn't be read (at all) if the Content-Length header is too large."""
+        result, deferred, protocol = self._build_multipart_response(250, 1)
+
+        # Deferred shouldn't be called yet.
+        self.assertFalse(deferred.called)
+
+        # Start sending data.
+        protocol.dataReceived(self.data1)
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)
+
+        # The data is never consumed.
+        self.assertEqual(result.getvalue(), b"")
+
+
 class ReadBodyWithMaxSizeTests(TestCase):
-    def _build_response(
-        self, length: Union[int, str] = UNKNOWN_LENGTH
-    ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
+    def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[
+        BytesIO,
+        "Deferred[int]",
+        _DiscardBodyWithMaxSizeProtocol,
+    ]:
         """Start reading the body, returns the response, result and proto"""
         response = Mock(length=length)
         result = BytesIO()