diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 2d9b733be0..21ecb81c99 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -26,77 +26,96 @@ from tests.unittest import TestCase
class ReadBodyWithMaxSizeTests(TestCase):
- def setUp(self):
+ def _build_response(self, length=UNKNOWN_LENGTH):
"""Start reading the body, returns the response, result and proto"""
- response = Mock(length=UNKNOWN_LENGTH)
- self.result = BytesIO()
- self.deferred = read_body_with_max_size(response, self.result, 6)
+ response = Mock(length=length)
+ result = BytesIO()
+ deferred = read_body_with_max_size(response, result, 6)
# Fish the protocol out of the response.
- self.protocol = response.deliverBody.call_args[0][0]
- self.protocol.transport = Mock()
+ protocol = response.deliverBody.call_args[0][0]
+ protocol.transport = Mock()
- def _cleanup_error(self):
+ return result, deferred, protocol
+
+ def _assert_error(self, deferred, protocol):
+ """Ensure that the expected error is received."""
+ self.assertIsInstance(deferred.result, Failure)
+ self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
+ protocol.transport.abortConnection.assert_called_once()
+
+ def _cleanup_error(self, deferred):
"""Ensure that the error in the Deferred is handled gracefully."""
called = [False]
def errback(f):
called[0] = True
- self.deferred.addErrback(errback)
+ deferred.addErrback(errback)
self.assertTrue(called[0])
def test_no_error(self):
"""A response that is NOT too large."""
+ result, deferred, protocol = self._build_response()
# Start sending data.
- self.protocol.dataReceived(b"12345")
+ protocol.dataReceived(b"12345")
# Close the connection.
- self.protocol.connectionLost(Failure(ResponseDone()))
+ protocol.connectionLost(Failure(ResponseDone()))
- self.assertEqual(self.result.getvalue(), b"12345")
- self.assertEqual(self.deferred.result, 5)
+ self.assertEqual(result.getvalue(), b"12345")
+ self.assertEqual(deferred.result, 5)
def test_too_large(self):
"""A response which is too large raises an exception."""
+ result, deferred, protocol = self._build_response()
# Start sending data.
- self.protocol.dataReceived(b"1234567890")
- # Close the connection.
- self.protocol.connectionLost(Failure(ResponseDone()))
+ protocol.dataReceived(b"1234567890")
- self.assertEqual(self.result.getvalue(), b"1234567890")
- self.assertIsInstance(self.deferred.result, Failure)
- self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
- self._cleanup_error()
+ self.assertEqual(result.getvalue(), b"1234567890")
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
def test_multiple_packets(self):
- """Data should be accummulated through mutliple packets."""
+ """Data should be accumulated through mutliple packets."""
+ result, deferred, protocol = self._build_response()
# Start sending data.
- self.protocol.dataReceived(b"12")
- self.protocol.dataReceived(b"34")
+ protocol.dataReceived(b"12")
+ protocol.dataReceived(b"34")
# Close the connection.
- self.protocol.connectionLost(Failure(ResponseDone()))
+ protocol.connectionLost(Failure(ResponseDone()))
- self.assertEqual(self.result.getvalue(), b"1234")
- self.assertEqual(self.deferred.result, 4)
+ self.assertEqual(result.getvalue(), b"1234")
+ self.assertEqual(deferred.result, 4)
def test_additional_data(self):
"""A connection can receive data after being closed."""
+ result, deferred, protocol = self._build_response()
# Start sending data.
- self.protocol.dataReceived(b"1234567890")
- self.assertIsInstance(self.deferred.result, Failure)
- self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
- self.protocol.transport.abortConnection.assert_called_once()
+ protocol.dataReceived(b"1234567890")
+ self._assert_error(deferred, protocol)
# More data might have come in.
- self.protocol.dataReceived(b"1234567890")
- # Close the connection.
- self.protocol.connectionLost(Failure(ResponseDone()))
+ protocol.dataReceived(b"1234567890")
+
+ self.assertEqual(result.getvalue(), b"1234567890")
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
+
+ def test_content_length(self):
+ """The body shouldn't be read (at all) if the Content-Length header is too large."""
+ result, deferred, protocol = self._build_response(length=10)
+
+ # Deferred shouldn't be called yet.
+ self.assertFalse(deferred.called)
+
+ # Start sending data.
+ protocol.dataReceived(b"12345")
+ self._assert_error(deferred, protocol)
+ self._cleanup_error(deferred)
- self.assertEqual(self.result.getvalue(), b"1234567890")
- self.assertIsInstance(self.deferred.result, Failure)
- self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
- self._cleanup_error()
+ # The data is never consumed.
+ self.assertEqual(result.getvalue(), b"")
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 9a56e1c14a..505ffcd300 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import os
+from unittest.mock import patch
import treq
from netaddr import IPSet
@@ -100,22 +102,36 @@ class MatrixFederationAgentTests(TestCase):
return http_protocol
- def test_http_request(self):
- agent = ProxyAgent(self.reactor)
+ def _test_request_direct_connection(self, agent, scheme, hostname, path):
+ """Runs a test case for a direct connection not going through a proxy.
- self.reactor.lookups["test.com"] = "1.2.3.4"
- d = agent.request(b"GET", b"http://test.com")
+ Args:
+ agent (ProxyAgent): the proxy agent being tested
+
+ scheme (bytes): expected to be either "http" or "https"
+
+ hostname (bytes): the hostname to connect to in the test
+
+ path (bytes): the path to connect to in the test
+ """
+ is_https = scheme == b"https"
+
+ self.reactor.lookups[hostname.decode()] = "1.2.3.4"
+ d = agent.request(b"GET", scheme + b"://" + hostname + b"/" + path)
# there should be a pending TCP connection
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, "1.2.3.4")
- self.assertEqual(port, 80)
+ self.assertEqual(port, 443 if is_https else 80)
# make a test server, and wire up the client
http_server = self._make_connection(
- client_factory, _get_test_protocol_factory()
+ client_factory,
+ _get_test_protocol_factory(),
+ ssl=is_https,
+ expected_sni=hostname if is_https else None,
)
# the FakeTransport is async, so we need to pump the reactor
@@ -126,8 +142,8 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0]
self.assertEqual(request.method, b"GET")
- self.assertEqual(request.path, b"/")
- self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+ self.assertEqual(request.path, b"/" + path)
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [hostname])
request.write(b"result")
request.finish()
@@ -137,48 +153,58 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
+ def test_http_request(self):
+ agent = ProxyAgent(self.reactor)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
def test_https_request(self):
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+ self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
- self.reactor.lookups["test.com"] = "1.2.3.4"
- d = agent.request(b"GET", b"https://test.com/abc")
+ def test_http_request_use_proxy_empty_environment(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
- # there should be a pending TCP connection
- clients = self.reactor.tcpClients
- self.assertEqual(len(clients), 1)
- (host, port, client_factory, _timeout, _bindAddress) = clients[0]
- self.assertEqual(host, "1.2.3.4")
- self.assertEqual(port, 443)
+ @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
+ def test_http_request_via_uppercase_no_proxy(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
- # make a test server, and wire up the client
- http_server = self._make_connection(
- client_factory,
- _get_test_protocol_factory(),
- ssl=True,
- expected_sni=b"test.com",
- )
-
- # the FakeTransport is async, so we need to pump the reactor
- self.reactor.advance(0)
-
- # now there should be a pending request
- self.assertEqual(len(http_server.requests), 1)
+ @patch.dict(
+ os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
+ )
+ def test_http_request_via_no_proxy(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
- request = http_server.requests[0]
- self.assertEqual(request.method, b"GET")
- self.assertEqual(request.path, b"/abc")
- self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
- request.write(b"result")
- request.finish()
+ @patch.dict(
+ os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
+ )
+ def test_https_request_via_no_proxy(self):
+ agent = ProxyAgent(
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ use_proxy=True,
+ )
+ self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
- self.reactor.advance(0)
+ @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
+ def test_http_request_via_no_proxy_star(self):
+ agent = ProxyAgent(self.reactor, use_proxy=True)
+ self._test_request_direct_connection(agent, b"http", b"test.com", b"")
- resp = self.successResultOf(d)
- body = self.successResultOf(treq.content(resp))
- self.assertEqual(body, b"result")
+ @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
+ def test_https_request_via_no_proxy_star(self):
+ agent = ProxyAgent(
+ self.reactor,
+ contextFactory=get_test_https_policy(),
+ use_proxy=True,
+ )
+ self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
+ @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
def test_http_request_via_proxy(self):
- agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+ agent = ProxyAgent(self.reactor, use_proxy=True)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
d = agent.request(b"GET", b"http://test.com")
@@ -214,11 +240,12 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
+ @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
def test_https_request_via_proxy(self):
agent = ProxyAgent(
self.reactor,
contextFactory=get_test_https_policy(),
- https_proxy=b"proxy.com",
+ use_proxy=True,
)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
@@ -294,6 +321,7 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
+ @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
def test_http_request_via_proxy_with_blacklist(self):
# The blacklist includes the configured proxy IP.
agent = ProxyAgent(
@@ -301,7 +329,7 @@ class MatrixFederationAgentTests(TestCase):
self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
),
self.reactor,
- http_proxy=b"proxy.com:8888",
+ use_proxy=True,
)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
@@ -338,7 +366,8 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
- def test_https_request_via_proxy_with_blacklist(self):
+ @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
+ def test_https_request_via_uppercase_proxy_with_blacklist(self):
# The blacklist includes the configured proxy IP.
agent = ProxyAgent(
BlacklistingReactorWrapper(
@@ -346,7 +375,7 @@ class MatrixFederationAgentTests(TestCase):
),
self.reactor,
contextFactory=get_test_https_policy(),
- https_proxy=b"proxy.com",
+ use_proxy=True,
)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|