summary refs log tree commit diff
path: root/tests/http/test_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http/test_client.py')
-rw-r--r--tests/http/test_client.py37
1 files changed, 23 insertions, 14 deletions
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 7e2f2a01cc..9cfe1ad0de 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -13,10 +13,12 @@
 #  limitations under the License.
 
 from io import BytesIO
+from typing import Tuple, Union
 from unittest.mock import Mock
 
 from netaddr import IPSet
 
+from twisted.internet.defer import Deferred
 from twisted.internet.error import DNSLookupError
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol
@@ -28,6 +30,7 @@ from synapse.http.client import (
     BlacklistingAgentWrapper,
     BlacklistingReactorWrapper,
     BodyExceededMaxSize,
+    _DiscardBodyWithMaxSizeProtocol,
     read_body_with_max_size,
 )
 
@@ -36,7 +39,9 @@ from tests.unittest import TestCase
 
 
 class ReadBodyWithMaxSizeTests(TestCase):
-    def _build_response(self, length=UNKNOWN_LENGTH):
+    def _build_response(
+        self, length: Union[int, str] = UNKNOWN_LENGTH
+    ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
         """Start reading the body, returns the response, result and proto"""
         response = Mock(length=length)
         result = BytesIO()
@@ -48,23 +53,27 @@ class ReadBodyWithMaxSizeTests(TestCase):
 
         return result, deferred, protocol
 
-    def _assert_error(self, deferred, protocol):
+    def _assert_error(
+        self, deferred: "Deferred[int]", protocol: _DiscardBodyWithMaxSizeProtocol
+    ) -> None:
         """Ensure that the expected error is received."""
-        self.assertIsInstance(deferred.result, Failure)
+        assert isinstance(deferred.result, Failure)
         self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
-        protocol.transport.abortConnection.assert_called_once()
+        assert protocol.transport is not None
+        # type-ignore: presumably abortConnection has been replaced with a Mock.
+        protocol.transport.abortConnection.assert_called_once()  # type: ignore[attr-defined]
 
-    def _cleanup_error(self, deferred):
+    def _cleanup_error(self, deferred: "Deferred[int]") -> None:
         """Ensure that the error in the Deferred is handled gracefully."""
         called = [False]
 
-        def errback(f):
+        def errback(f: Failure) -> None:
             called[0] = True
 
         deferred.addErrback(errback)
         self.assertTrue(called[0])
 
-    def test_no_error(self):
+    def test_no_error(self) -> None:
         """A response that is NOT too large."""
         result, deferred, protocol = self._build_response()
 
@@ -76,7 +85,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
         self.assertEqual(result.getvalue(), b"12345")
         self.assertEqual(deferred.result, 5)
 
-    def test_too_large(self):
+    def test_too_large(self) -> None:
         """A response which is too large raises an exception."""
         result, deferred, protocol = self._build_response()
 
@@ -87,7 +96,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
         self._assert_error(deferred, protocol)
         self._cleanup_error(deferred)
 
-    def test_multiple_packets(self):
+    def test_multiple_packets(self) -> None:
         """Data should be accumulated through mutliple packets."""
         result, deferred, protocol = self._build_response()
 
@@ -100,7 +109,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
         self.assertEqual(result.getvalue(), b"1234")
         self.assertEqual(deferred.result, 4)
 
-    def test_additional_data(self):
+    def test_additional_data(self) -> None:
         """A connection can receive data after being closed."""
         result, deferred, protocol = self._build_response()
 
@@ -115,7 +124,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
         self._assert_error(deferred, protocol)
         self._cleanup_error(deferred)
 
-    def test_content_length(self):
+    def test_content_length(self) -> None:
         """The body shouldn't be read (at all) if the Content-Length header is too large."""
         result, deferred, protocol = self._build_response(length=10)
 
@@ -132,7 +141,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
 
 
 class BlacklistingAgentTest(TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.reactor, self.clock = get_clock()
 
         self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
@@ -151,7 +160,7 @@ class BlacklistingAgentTest(TestCase):
         self.ip_whitelist = IPSet([self.allowed_ip.decode()])
         self.ip_blacklist = IPSet(["5.0.0.0/8"])
 
-    def test_reactor(self):
+    def test_reactor(self) -> None:
         """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
         agent = Agent(
             BlacklistingReactorWrapper(
@@ -197,7 +206,7 @@ class BlacklistingAgentTest(TestCase):
             response = self.successResultOf(d)
             self.assertEqual(response.code, 200)
 
-    def test_agent(self):
+    def test_agent(self) -> None:
         """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
         agent = BlacklistingAgentWrapper(
             Agent(self.reactor),