summary refs log tree commit diff
path: root/tests/http/test_proxyagent.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http/test_proxyagent.py')
-rw-r--r--tests/http/test_proxyagent.py134
1 files changed, 78 insertions, 56 deletions
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 2db77c6a73..cc175052ac 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -14,7 +14,7 @@
 import base64
 import logging
 import os
-from typing import Iterable, Optional
+from typing import List, Optional
 from unittest.mock import patch
 
 import treq
@@ -22,9 +22,13 @@ from netaddr import IPSet
 from parameterized import parameterized
 
 from twisted.internet import interfaces  # noqa: F401
-from twisted.internet.endpoints import HostnameEndpoint, _WrapperEndpoint
+from twisted.internet.endpoints import (
+    HostnameEndpoint,
+    _WrapperEndpoint,
+    _WrappingProtocol,
+)
 from twisted.internet.interfaces import IProtocol, IProtocolFactory
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
 from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
 from twisted.web.http import HTTPChannel
 
@@ -32,9 +36,14 @@ from synapse.http.client import BlacklistingReactorWrapper
 from synapse.http.connectproxyclient import ProxyCredentials
 from synapse.http.proxyagent import ProxyAgent, parse_proxy
 
-from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.http import (
+    TestServerTLSConnectionFactory,
+    dummy_address,
+    get_test_https_policy,
+)
 from tests.server import FakeTransport, ThreadedMemoryReactorClock
 from tests.unittest import TestCase
+from tests.utils import checked_cast
 
 logger = logging.getLogger(__name__)
 
@@ -183,7 +192,7 @@ class ProxyParserTests(TestCase):
         expected_hostname: bytes,
         expected_port: int,
         expected_credentials: Optional[bytes],
-    ):
+    ) -> None:
         """
         Tests that a given proxy URL will be broken into the components.
         Args:
@@ -209,7 +218,7 @@ class ProxyParserTests(TestCase):
 
 
 class MatrixFederationAgentTests(TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.reactor = ThreadedMemoryReactorClock()
 
     def _make_connection(
@@ -218,7 +227,7 @@ class MatrixFederationAgentTests(TestCase):
         server_factory: IProtocolFactory,
         ssl: bool = False,
         expected_sni: Optional[bytes] = None,
-        tls_sanlist: Optional[Iterable[bytes]] = None,
+        tls_sanlist: Optional[List[bytes]] = None,
     ) -> IProtocol:
         """Builds a test server, and completes the outgoing client connection
 
@@ -244,7 +253,8 @@ class MatrixFederationAgentTests(TestCase):
         if ssl:
             server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
 
-        server_protocol = server_factory.buildProtocol(None)
+        server_protocol = server_factory.buildProtocol(dummy_address)
+        assert server_protocol is not None
 
         # now, tell the client protocol factory to build the client protocol,
         # and wire the output of said protocol up to the server via
@@ -252,7 +262,8 @@ class MatrixFederationAgentTests(TestCase):
         #
         # Normally this would be done by the TCP socket code in Twisted, but we are
         # stubbing that out here.
-        client_protocol = client_factory.buildProtocol(None)
+        client_protocol = client_factory.buildProtocol(dummy_address)
+        assert client_protocol is not None
         client_protocol.makeConnection(
             FakeTransport(server_protocol, self.reactor, client_protocol)
         )
@@ -263,6 +274,7 @@ class MatrixFederationAgentTests(TestCase):
         )
 
         if ssl:
+            assert isinstance(server_protocol, TLSMemoryBIOProtocol)
             http_protocol = server_protocol.wrappedProtocol
             tls_connection = server_protocol._tlsConnection
         else:
@@ -288,7 +300,7 @@ class MatrixFederationAgentTests(TestCase):
         scheme: bytes,
         hostname: bytes,
         path: bytes,
-    ):
+    ) -> None:
         """Runs a test case for a direct connection not going through a proxy.
 
         Args:
@@ -319,6 +331,7 @@ class MatrixFederationAgentTests(TestCase):
             ssl=is_https,
             expected_sni=hostname if is_https else None,
         )
+        assert isinstance(http_server, HTTPChannel)
 
         # the FakeTransport is async, so we need to pump the reactor
         self.reactor.advance(0)
@@ -339,34 +352,34 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")
 
-    def test_http_request(self):
+    def test_http_request(self) -> None:
         agent = ProxyAgent(self.reactor)
         self._test_request_direct_connection(agent, b"http", b"test.com", b"")
 
-    def test_https_request(self):
+    def test_https_request(self) -> None:
         agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
         self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
 
-    def test_http_request_use_proxy_empty_environment(self):
+    def test_http_request_use_proxy_empty_environment(self) -> None:
         agent = ProxyAgent(self.reactor, use_proxy=True)
         self._test_request_direct_connection(agent, b"http", b"test.com", b"")
 
     @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
-    def test_http_request_via_uppercase_no_proxy(self):
+    def test_http_request_via_uppercase_no_proxy(self) -> None:
         agent = ProxyAgent(self.reactor, use_proxy=True)
         self._test_request_direct_connection(agent, b"http", b"test.com", b"")
 
     @patch.dict(
         os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
     )
-    def test_http_request_via_no_proxy(self):
+    def test_http_request_via_no_proxy(self) -> None:
         agent = ProxyAgent(self.reactor, use_proxy=True)
         self._test_request_direct_connection(agent, b"http", b"test.com", b"")
 
     @patch.dict(
         os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
     )
-    def test_https_request_via_no_proxy(self):
+    def test_https_request_via_no_proxy(self) -> None:
         agent = ProxyAgent(
             self.reactor,
             contextFactory=get_test_https_policy(),
@@ -375,12 +388,12 @@ class MatrixFederationAgentTests(TestCase):
         self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
 
     @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
-    def test_http_request_via_no_proxy_star(self):
+    def test_http_request_via_no_proxy_star(self) -> None:
         agent = ProxyAgent(self.reactor, use_proxy=True)
         self._test_request_direct_connection(agent, b"http", b"test.com", b"")
 
     @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
-    def test_https_request_via_no_proxy_star(self):
+    def test_https_request_via_no_proxy_star(self) -> None:
         agent = ProxyAgent(
             self.reactor,
             contextFactory=get_test_https_policy(),
@@ -389,7 +402,7 @@ class MatrixFederationAgentTests(TestCase):
         self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
 
     @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
-    def test_http_request_via_proxy(self):
+    def test_http_request_via_proxy(self) -> None:
         """
         Tests that requests can be made through a proxy.
         """
@@ -401,7 +414,7 @@ class MatrixFederationAgentTests(TestCase):
         os.environ,
         {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"},
     )
-    def test_http_request_via_proxy_with_auth(self):
+    def test_http_request_via_proxy_with_auth(self) -> None:
         """
         Tests that authenticated requests can be made through a proxy.
         """
@@ -412,7 +425,7 @@ class MatrixFederationAgentTests(TestCase):
     @patch.dict(
         os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"}
     )
-    def test_http_request_via_https_proxy(self):
+    def test_http_request_via_https_proxy(self) -> None:
         self._do_http_request_via_proxy(
             expect_proxy_ssl=True, expected_auth_credentials=None
         )
@@ -424,13 +437,13 @@ class MatrixFederationAgentTests(TestCase):
             "no_proxy": "unused.com",
         },
     )
-    def test_http_request_via_https_proxy_with_auth(self):
+    def test_http_request_via_https_proxy_with_auth(self) -> None:
         self._do_http_request_via_proxy(
             expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
         )
 
     @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
-    def test_https_request_via_proxy(self):
+    def test_https_request_via_proxy(self) -> None:
         """Tests that TLS-encrypted requests can be made through a proxy"""
         self._do_https_request_via_proxy(
             expect_proxy_ssl=False, expected_auth_credentials=None
@@ -440,7 +453,7 @@ class MatrixFederationAgentTests(TestCase):
         os.environ,
         {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
     )
-    def test_https_request_via_proxy_with_auth(self):
+    def test_https_request_via_proxy_with_auth(self) -> None:
         """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
         self._do_https_request_via_proxy(
             expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
@@ -449,7 +462,7 @@ class MatrixFederationAgentTests(TestCase):
     @patch.dict(
         os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
     )
-    def test_https_request_via_https_proxy(self):
+    def test_https_request_via_https_proxy(self) -> None:
         """Tests that TLS-encrypted requests can be made through a proxy"""
         self._do_https_request_via_proxy(
             expect_proxy_ssl=True, expected_auth_credentials=None
@@ -459,7 +472,7 @@ class MatrixFederationAgentTests(TestCase):
         os.environ,
         {"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
     )
-    def test_https_request_via_https_proxy_with_auth(self):
+    def test_https_request_via_https_proxy_with_auth(self) -> None:
         """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
         self._do_https_request_via_proxy(
             expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
@@ -469,7 +482,7 @@ class MatrixFederationAgentTests(TestCase):
         self,
         expect_proxy_ssl: bool = False,
         expected_auth_credentials: Optional[bytes] = None,
-    ):
+    ) -> None:
         """Send a http request via an agent and check that it is correctly received at
             the proxy. The proxy can use either http or https.
         Args:
@@ -501,6 +514,7 @@ class MatrixFederationAgentTests(TestCase):
             tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
             expected_sni=b"proxy.com" if expect_proxy_ssl else None,
         )
+        assert isinstance(http_server, HTTPChannel)
 
         # the FakeTransport is async, so we need to pump the reactor
         self.reactor.advance(0)
@@ -542,7 +556,7 @@ class MatrixFederationAgentTests(TestCase):
         self,
         expect_proxy_ssl: bool = False,
         expected_auth_credentials: Optional[bytes] = None,
-    ):
+    ) -> None:
         """Send a https request via an agent and check that it is correctly received at
             the proxy and client. The proxy can use either http or https.
         Args:
@@ -606,10 +620,11 @@ class MatrixFederationAgentTests(TestCase):
         # now we make another test server to act as the upstream HTTP server.
         server_ssl_protocol = _wrap_server_factory_for_tls(
             _get_test_protocol_factory()
-        ).buildProtocol(None)
+        ).buildProtocol(dummy_address)
 
         # Tell the HTTP server to send outgoing traffic back via the proxy's transport.
         proxy_server_transport = proxy_server.transport
+        assert proxy_server_transport is not None
         server_ssl_protocol.makeConnection(proxy_server_transport)
 
         # ... and replace the protocol on the proxy's transport with the
@@ -629,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
         else:
             assert isinstance(proxy_server_transport, FakeTransport)
             client_protocol = proxy_server_transport.other
-            c2s_transport = client_protocol.transport
+            assert isinstance(client_protocol, Protocol)
+            c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
             c2s_transport.other = server_ssl_protocol
 
         self.reactor.advance(0)
@@ -644,6 +660,7 @@ class MatrixFederationAgentTests(TestCase):
 
         # now there should be a pending request
         http_server = server_ssl_protocol.wrappedProtocol
+        assert isinstance(http_server, HTTPChannel)
         self.assertEqual(len(http_server.requests), 1)
 
         request = http_server.requests[0]
@@ -667,7 +684,7 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(body, b"result")
 
     @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
-    def test_http_request_via_proxy_with_blacklist(self):
+    def test_http_request_via_proxy_with_blacklist(self) -> None:
         # The blacklist includes the configured proxy IP.
         agent = ProxyAgent(
             BlacklistingReactorWrapper(
@@ -691,6 +708,7 @@ class MatrixFederationAgentTests(TestCase):
         http_server = self._make_connection(
             client_factory, _get_test_protocol_factory()
         )
+        assert isinstance(http_server, HTTPChannel)
 
         # the FakeTransport is async, so we need to pump the reactor
         self.reactor.advance(0)
@@ -712,7 +730,7 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(body, b"result")
 
     @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
-    def test_https_request_via_uppercase_proxy_with_blacklist(self):
+    def test_https_request_via_uppercase_proxy_with_blacklist(self) -> None:
         # The blacklist includes the configured proxy IP.
         agent = ProxyAgent(
             BlacklistingReactorWrapper(
@@ -737,11 +755,17 @@ class MatrixFederationAgentTests(TestCase):
         proxy_server = self._make_connection(
             client_factory, _get_test_protocol_factory()
         )
+        assert isinstance(proxy_server, HTTPChannel)
 
         # fish the transports back out so that we can do the old switcheroo
-        s2c_transport = proxy_server.transport
-        client_protocol = s2c_transport.other
-        c2s_transport = client_protocol.transport
+        # To help mypy out with the various Protocols and wrappers and mocks, we do
+        # some explicit casting. Without the casts, we hit the bug I reported at
+        # https://github.com/Shoobx/mypy-zope/issues/91 .
+        # We also double-checked these casts at runtime (test-time) because I found it
+        # quite confusing to deduce these types in the first place!
+        s2c_transport = checked_cast(FakeTransport, proxy_server.transport)
+        client_protocol = checked_cast(_WrappingProtocol, s2c_transport.other)
+        c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
 
         # the FakeTransport is async, so we need to pump the reactor
         self.reactor.advance(0)
@@ -762,8 +786,10 @@ class MatrixFederationAgentTests(TestCase):
 
         # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
         ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
-        ssl_protocol = ssl_factory.buildProtocol(None)
+        ssl_protocol = ssl_factory.buildProtocol(dummy_address)
+        assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
         http_server = ssl_protocol.wrappedProtocol
+        assert isinstance(http_server, HTTPChannel)
 
         ssl_protocol.makeConnection(
             FakeTransport(client_protocol, self.reactor, ssl_protocol)
@@ -797,39 +823,35 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(body, b"result")
 
     @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
-    def test_proxy_with_no_scheme(self):
+    def test_proxy_with_no_scheme(self) -> None:
         http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
-        self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
-        self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
-        self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+        proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+        self.assertEqual(proxy_ep._hostStr, "proxy.com")
+        self.assertEqual(proxy_ep._port, 8888)
 
     @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
-    def test_proxy_with_unsupported_scheme(self):
+    def test_proxy_with_unsupported_scheme(self) -> None:
         with self.assertRaises(ValueError):
             ProxyAgent(self.reactor, use_proxy=True)
 
     @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
-    def test_proxy_with_http_scheme(self):
+    def test_proxy_with_http_scheme(self) -> None:
         http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
-        self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
-        self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
-        self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+        proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+        self.assertEqual(proxy_ep._hostStr, "proxy.com")
+        self.assertEqual(proxy_ep._port, 8888)
 
     @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
-    def test_proxy_with_https_scheme(self):
+    def test_proxy_with_https_scheme(self) -> None:
         https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
-        self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
-        self.assertEqual(
-            https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com"
-        )
-        self.assertEqual(
-            https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888
-        )
+        proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint)
+        self.assertEqual(proxy_ep._wrappedEndpoint._hostStr, "proxy.com")
+        self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
 
 
 def _wrap_server_factory_for_tls(
-    factory: IProtocolFactory, sanlist: Iterable[bytes] = None
-) -> IProtocolFactory:
+    factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
+) -> TLSMemoryBIOFactory:
     """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
 
     The resultant factory will create a TLS server which presents a certificate
@@ -865,6 +887,6 @@ def _get_test_protocol_factory() -> IProtocolFactory:
     return server_factory
 
 
-def _log_request(request: str):
+def _log_request(request: str) -> None:
     """Implements Factory.log, which is expected by Request.finish"""
     logger.info(f"Completed request {request}")