summary refs log tree commit diff
path: root/tests/http/test_matrixfederationclient.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http/test_matrixfederationclient.py')
-rw-r--r--tests/http/test_matrixfederationclient.py53
1 files changed, 31 insertions, 22 deletions
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index be9eaf34e8..fdd22a8e94 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -11,16 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from typing import Generator
 from unittest.mock import Mock
 
 from netaddr import IPSet
 from parameterized import parameterized
 
 from twisted.internet import defer
-from twisted.internet.defer import TimeoutError
+from twisted.internet.defer import Deferred, TimeoutError
 from twisted.internet.error import ConnectingCancelledError, DNSLookupError
-from twisted.test.proto_helpers import StringTransport
+from twisted.test.proto_helpers import MemoryReactor, StringTransport
 from twisted.web.client import ResponseNeverReceived
 from twisted.web.http import HTTPChannel
 
@@ -30,34 +30,43 @@ from synapse.http.matrixfederationclient import (
     MatrixFederationHttpClient,
     MatrixFederationRequest,
 )
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import (
+    SENTINEL_CONTEXT,
+    LoggingContext,
+    LoggingContextOrSentinel,
+    current_context,
+)
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.server import FakeTransport
 from tests.unittest import HomeserverTestCase
 
 
-def check_logcontext(context):
+def check_logcontext(context: LoggingContextOrSentinel) -> None:
     current = current_context()
     if current is not context:
         raise AssertionError("Expected logcontext %s but was %s" % (context, current))
 
 
 class FederationClientTests(HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
         return hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         self.cl = MatrixFederationHttpClient(self.hs, None)
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
-    def test_client_get(self):
+    def test_client_get(self) -> None:
         """
         happy-path test of a GET request
         """
 
         @defer.inlineCallbacks
-        def do_request():
+        def do_request() -> Generator["Deferred[object]", object, object]:
             with LoggingContext("one") as context:
                 fetch_d = defer.ensureDeferred(
                     self.cl.get_json("testserv:8008", "foo/bar")
@@ -119,7 +128,7 @@ class FederationClientTests(HomeserverTestCase):
         # check the response is as expected
         self.assertEqual(res, {"a": 1})
 
-    def test_dns_error(self):
+    def test_dns_error(self) -> None:
         """
         If the DNS lookup returns an error, it will bubble up.
         """
@@ -132,7 +141,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertIsInstance(f.value, RequestSendFailed)
         self.assertIsInstance(f.value.inner_exception, DNSLookupError)
 
-    def test_client_connection_refused(self):
+    def test_client_connection_refused(self) -> None:
         d = defer.ensureDeferred(
             self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
         )
@@ -156,7 +165,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertIsInstance(f.value, RequestSendFailed)
         self.assertIs(f.value.inner_exception, e)
 
-    def test_client_never_connect(self):
+    def test_client_never_connect(self) -> None:
         """
         If the HTTP request is not connected and is timed out, it'll give a
         ConnectingCancelledError or TimeoutError.
@@ -188,7 +197,7 @@ class FederationClientTests(HomeserverTestCase):
             f.value.inner_exception, (ConnectingCancelledError, TimeoutError)
         )
 
-    def test_client_connect_no_response(self):
+    def test_client_connect_no_response(self) -> None:
         """
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
@@ -222,7 +231,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertIsInstance(f.value, RequestSendFailed)
         self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)
 
-    def test_client_ip_range_blacklist(self):
+    def test_client_ip_range_blacklist(self) -> None:
         """Ensure that Synapse does not try to connect to blacklisted IPs"""
 
         # Set up the ip_range blacklist
@@ -292,7 +301,7 @@ class FederationClientTests(HomeserverTestCase):
         f = self.failureResultOf(d, RequestSendFailed)
         self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError)
 
-    def test_client_gets_headers(self):
+    def test_client_gets_headers(self) -> None:
         """
         Once the client gets the headers, _request returns successfully.
         """
@@ -319,7 +328,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertEqual(r.code, 200)
 
     @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
-    def test_timeout_reading_body(self, method_name: str):
+    def test_timeout_reading_body(self, method_name: str) -> None:
         """
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a RequestSendFailed with can_retry.
@@ -351,7 +360,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertTrue(f.value.can_retry)
         self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)
 
-    def test_client_requires_trailing_slashes(self):
+    def test_client_requires_trailing_slashes(self) -> None:
         """
         If a connection is made to a client but the client rejects it due to
         requiring a trailing slash. We need to retry the request with a
@@ -405,7 +414,7 @@ class FederationClientTests(HomeserverTestCase):
         r = self.successResultOf(d)
         self.assertEqual(r, {})
 
-    def test_client_does_not_retry_on_400_plus(self):
+    def test_client_does_not_retry_on_400_plus(self) -> None:
         """
         Another test for trailing slashes but now test that we don't retry on
         trailing slashes on a non-400/M_UNRECOGNIZED response.
@@ -450,7 +459,7 @@ class FederationClientTests(HomeserverTestCase):
         # We should get a 404 failure response
         self.failureResultOf(d)
 
-    def test_client_sends_body(self):
+    def test_client_sends_body(self) -> None:
         defer.ensureDeferred(
             self.cl.post_json(
                 "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
@@ -474,7 +483,7 @@ class FederationClientTests(HomeserverTestCase):
         content = request.content.read()
         self.assertEqual(content, b'{"a":"b"}')
 
-    def test_closes_connection(self):
+    def test_closes_connection(self) -> None:
         """Check that the client closes unused HTTP connections"""
         d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
 
@@ -514,7 +523,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertTrue(conn.disconnecting)
 
     @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
-    def test_json_error(self, return_value):
+    def test_json_error(self, return_value: bytes) -> None:
         """
         Test what happens if invalid JSON is returned from the remote endpoint.
         """
@@ -560,7 +569,7 @@ class FederationClientTests(HomeserverTestCase):
         f = self.failureResultOf(test_d)
         self.assertIsInstance(f.value, RequestSendFailed)
 
-    def test_too_big(self):
+    def test_too_big(self) -> None:
         """
         Test what happens if a huge response is returned from the remote endpoint.
         """