diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 7e2f2a01cc..9cfe1ad0de 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -13,10 +13,12 @@
# limitations under the License.
from io import BytesIO
+from typing import Tuple, Union
from unittest.mock import Mock
from netaddr import IPSet
+from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
@@ -28,6 +30,7 @@ from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BodyExceededMaxSize,
+ _DiscardBodyWithMaxSizeProtocol,
read_body_with_max_size,
)
@@ -36,7 +39,9 @@ from tests.unittest import TestCase
class ReadBodyWithMaxSizeTests(TestCase):
- def _build_response(self, length=UNKNOWN_LENGTH):
+ def _build_response(
+ self, length: Union[int, str] = UNKNOWN_LENGTH
+ ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
"""Start reading the body, returns the response, result and proto"""
response = Mock(length=length)
result = BytesIO()
@@ -48,23 +53,27 @@ class ReadBodyWithMaxSizeTests(TestCase):
return result, deferred, protocol
- def _assert_error(self, deferred, protocol):
+ def _assert_error(
+ self, deferred: "Deferred[int]", protocol: _DiscardBodyWithMaxSizeProtocol
+ ) -> None:
"""Ensure that the expected error is received."""
- self.assertIsInstance(deferred.result, Failure)
+ assert isinstance(deferred.result, Failure)
self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
- protocol.transport.abortConnection.assert_called_once()
+ assert protocol.transport is not None
+ # type-ignore: presumably abortConnection has been replaced with a Mock.
+ protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined]
- def _cleanup_error(self, deferred):
+ def _cleanup_error(self, deferred: "Deferred[int]") -> None:
"""Ensure that the error in the Deferred is handled gracefully."""
called = [False]
- def errback(f):
+ def errback(f: Failure) -> None:
called[0] = True
deferred.addErrback(errback)
self.assertTrue(called[0])
- def test_no_error(self):
+ def test_no_error(self) -> None:
"""A response that is NOT too large."""
result, deferred, protocol = self._build_response()
@@ -76,7 +85,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self.assertEqual(result.getvalue(), b"12345")
self.assertEqual(deferred.result, 5)
- def test_too_large(self):
+ def test_too_large(self) -> None:
"""A response which is too large raises an exception."""
result, deferred, protocol = self._build_response()
@@ -87,7 +96,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
- def test_multiple_packets(self):
+ def test_multiple_packets(self) -> None:
"""Data should be accumulated through mutliple packets."""
result, deferred, protocol = self._build_response()
@@ -100,7 +109,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self.assertEqual(result.getvalue(), b"1234")
self.assertEqual(deferred.result, 4)
- def test_additional_data(self):
+ def test_additional_data(self) -> None:
"""A connection can receive data after being closed."""
result, deferred, protocol = self._build_response()
@@ -115,7 +124,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
- def test_content_length(self):
+ def test_content_length(self) -> None:
"""The body shouldn't be read (at all) if the Content-Length header is too large."""
result, deferred, protocol = self._build_response(length=10)
@@ -132,7 +141,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
class BlacklistingAgentTest(TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.reactor, self.clock = get_clock()
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
@@ -151,7 +160,7 @@ class BlacklistingAgentTest(TestCase):
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
self.ip_blacklist = IPSet(["5.0.0.0/8"])
- def test_reactor(self):
+ def test_reactor(self) -> None:
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
agent = Agent(
BlacklistingReactorWrapper(
@@ -197,7 +206,7 @@ class BlacklistingAgentTest(TestCase):
response = self.successResultOf(d)
self.assertEqual(response.code, 200)
- def test_agent(self):
+ def test_agent(self) -> None:
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
agent = BlacklistingAgentWrapper(
Agent(self.reactor),
|