summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/http/test_proxyagent.py130
1 files changed, 130 insertions, 0 deletions
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 22abf76515..9a56e1c14a 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -15,12 +15,14 @@
 import logging
 
 import treq
+from netaddr import IPSet
 
 from twisted.internet import interfaces  # noqa: F401
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.web.http import HTTPChannel
 
+from synapse.http.client import BlacklistingReactorWrapper
 from synapse.http.proxyagent import ProxyAgent
 
 from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
@@ -292,6 +294,134 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")
 
+    def test_http_request_via_proxy_with_blacklist(self):
+        # The blacklist includes the configured proxy IP.
+        agent = ProxyAgent(
+            BlacklistingReactorWrapper(
+                self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+            ),
+            self.reactor,
+            http_proxy=b"proxy.com:8888",
+        )
+
+        self.reactor.lookups["proxy.com"] = "1.2.3.5"
+        d = agent.request(b"GET", b"http://test.com")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.5")
+        self.assertEqual(port, 8888)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"http://test.com")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
+    def test_https_request_via_proxy_with_blacklist(self):
+        # The blacklist includes the configured proxy IP.
+        agent = ProxyAgent(
+            BlacklistingReactorWrapper(
+                self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+            ),
+            self.reactor,
+            contextFactory=get_test_https_policy(),
+            https_proxy=b"proxy.com",
+        )
+
+        self.reactor.lookups["proxy.com"] = "1.2.3.5"
+        d = agent.request(b"GET", b"https://test.com/abc")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.5")
+        self.assertEqual(port, 1080)
+
+        # make a test HTTP server, and wire up the client
+        proxy_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # fish the transports back out so that we can do the old switcheroo
+        s2c_transport = proxy_server.transport
+        client_protocol = s2c_transport.other
+        c2s_transport = client_protocol.transport
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending CONNECT request
+        self.assertEqual(len(proxy_server.requests), 1)
+
+        request = proxy_server.requests[0]
+        self.assertEqual(request.method, b"CONNECT")
+        self.assertEqual(request.path, b"test.com:443")
+
+        # tell the proxy server not to close the connection
+        proxy_server.persistent = True
+
+        # this just stops the http Request trying to do a chunked response
+        # request.setHeader(b"Content-Length", b"0")
+        request.finish()
+
+        # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+        ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+        ssl_protocol = ssl_factory.buildProtocol(None)
+        http_server = ssl_protocol.wrappedProtocol
+
+        ssl_protocol.makeConnection(
+            FakeTransport(client_protocol, self.reactor, ssl_protocol)
+        )
+        c2s_transport.other = ssl_protocol
+
+        self.reactor.advance(0)
+
+        server_name = ssl_protocol._tlsConnection.get_servername()
+        expected_sni = b"test.com"
+        self.assertEqual(
+            server_name,
+            expected_sni,
+            "Expected SNI %s but got %s" % (expected_sni, server_name),
+        )
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/abc")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
 
 def _wrap_server_factory_for_tls(factory, sanlist=None):
     """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory