diff options
author | Patrick Cloke <clokep@users.noreply.github.com> | 2021-03-01 12:45:00 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-03-01 12:45:00 -0500 |
commit | 16ec8c3272ba014620c95285cc7c6539ee4e9ed2 (patch) | |
tree | bc715e28b764c8975eb3b0db9ffb3e43c7e60b93 /tests/http/test_client.py | |
parent | Use the proper Request in type hints. (#9515) (diff) | |
download | synapse-16ec8c3272ba014620c95285cc7c6539ee4e9ed2.tar.xz |
(Hopefully) stop leaking file descriptors in media repo. (#9497)
By consuming the response if the headers imply that the content is too large.
Diffstat (limited to 'tests/http/test_client.py')
-rw-r--r-- | tests/http/test_client.py | 91 |
1 files changed, 55 insertions, 36 deletions
diff --git a/tests/http/test_client.py b/tests/http/test_client.py index 2d9b733be0..21ecb81c99 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -26,77 +26,96 @@ from tests.unittest import TestCase class ReadBodyWithMaxSizeTests(TestCase): - def setUp(self): + def _build_response(self, length=UNKNOWN_LENGTH): """Start reading the body, returns the response, result and proto""" - response = Mock(length=UNKNOWN_LENGTH) - self.result = BytesIO() - self.deferred = read_body_with_max_size(response, self.result, 6) + response = Mock(length=length) + result = BytesIO() + deferred = read_body_with_max_size(response, result, 6) # Fish the protocol out of the response. - self.protocol = response.deliverBody.call_args[0][0] - self.protocol.transport = Mock() + protocol = response.deliverBody.call_args[0][0] + protocol.transport = Mock() - def _cleanup_error(self): + return result, deferred, protocol + + def _assert_error(self, deferred, protocol): + """Ensure that the expected error is received.""" + self.assertIsInstance(deferred.result, Failure) + self.assertIsInstance(deferred.result.value, BodyExceededMaxSize) + protocol.transport.abortConnection.assert_called_once() + + def _cleanup_error(self, deferred): """Ensure that the error in the Deferred is handled gracefully.""" called = [False] def errback(f): called[0] = True - self.deferred.addErrback(errback) + deferred.addErrback(errback) self.assertTrue(called[0]) def test_no_error(self): """A response that is NOT too large.""" + result, deferred, protocol = self._build_response() # Start sending data. - self.protocol.dataReceived(b"12345") + protocol.dataReceived(b"12345") # Close the connection. - self.protocol.connectionLost(Failure(ResponseDone())) + protocol.connectionLost(Failure(ResponseDone())) - self.assertEqual(self.result.getvalue(), b"12345") - self.assertEqual(self.deferred.result, 5) + self.assertEqual(result.getvalue(), b"12345") + self.assertEqual(deferred.result, 5) def test_too_large(self): """A response which is too large raises an exception.""" + result, deferred, protocol = self._build_response() # Start sending data. - self.protocol.dataReceived(b"1234567890") - # Close the connection. - self.protocol.connectionLost(Failure(ResponseDone())) + protocol.dataReceived(b"1234567890") - self.assertEqual(self.result.getvalue(), b"1234567890") - self.assertIsInstance(self.deferred.result, Failure) - self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) - self._cleanup_error() + self.assertEqual(result.getvalue(), b"1234567890") + self._assert_error(deferred, protocol) + self._cleanup_error(deferred) def test_multiple_packets(self): - """Data should be accummulated through mutliple packets.""" + """Data should be accumulated through mutliple packets.""" + result, deferred, protocol = self._build_response() # Start sending data. - self.protocol.dataReceived(b"12") - self.protocol.dataReceived(b"34") + protocol.dataReceived(b"12") + protocol.dataReceived(b"34") # Close the connection. - self.protocol.connectionLost(Failure(ResponseDone())) + protocol.connectionLost(Failure(ResponseDone())) - self.assertEqual(self.result.getvalue(), b"1234") - self.assertEqual(self.deferred.result, 4) + self.assertEqual(result.getvalue(), b"1234") + self.assertEqual(deferred.result, 4) def test_additional_data(self): """A connection can receive data after being closed.""" + result, deferred, protocol = self._build_response() # Start sending data. - self.protocol.dataReceived(b"1234567890") - self.assertIsInstance(self.deferred.result, Failure) - self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) - self.protocol.transport.abortConnection.assert_called_once() + protocol.dataReceived(b"1234567890") + self._assert_error(deferred, protocol) # More data might have come in. - self.protocol.dataReceived(b"1234567890") - # Close the connection. - self.protocol.connectionLost(Failure(ResponseDone())) + protocol.dataReceived(b"1234567890") + + self.assertEqual(result.getvalue(), b"1234567890") + self._assert_error(deferred, protocol) + self._cleanup_error(deferred) + + def test_content_length(self): + """The body shouldn't be read (at all) if the Content-Length header is too large.""" + result, deferred, protocol = self._build_response(length=10) + + # Deferred shouldn't be called yet. + self.assertFalse(deferred.called) + + # Start sending data. + protocol.dataReceived(b"12345") + self._assert_error(deferred, protocol) + self._cleanup_error(deferred) - self.assertEqual(self.result.getvalue(), b"1234567890") - self.assertIsInstance(self.deferred.result, Failure) - self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize) - self._cleanup_error() + # The data is never consumed. + self.assertEqual(result.getvalue(), b"") |