summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9657.feature1
-rw-r--r--synapse/http/connectproxyclient.py96
-rw-r--r--synapse/http/proxyagent.py81
-rw-r--r--tests/http/test_proxyagent.py40
4 files changed, 184 insertions, 34 deletions
diff --git a/changelog.d/9657.feature b/changelog.d/9657.feature
new file mode 100644
index 0000000000..c56a615a8b
--- /dev/null
+++ b/changelog.d/9657.feature
@@ -0,0 +1 @@
+Add support for credentials for proxy authentication in the `HTTPS_PROXY` environment variable.
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 856e28454f..b797e3ce80 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -19,9 +19,10 @@ from zope.interface import implementer
 
 from twisted.internet import defer, protocol
 from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.internet.protocol import connectionDone
+from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
 from twisted.web import http
+from twisted.web.http_headers import Headers
 
 logger = logging.getLogger(__name__)
 
@@ -43,23 +44,33 @@ class HTTPConnectProxyEndpoint:
 
     Args:
         reactor: the Twisted reactor to use for the connection
-        proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
-            proxy
-        host (bytes): hostname that we want to CONNECT to
-        port (int): port that we want to connect to
+        proxy_endpoint: the endpoint to use to connect to the proxy
+        host: hostname that we want to CONNECT to
+        port: port that we want to connect to
+        headers: Extra HTTP headers to include in the CONNECT request
     """
 
-    def __init__(self, reactor, proxy_endpoint, host, port):
+    def __init__(
+        self,
+        reactor: IReactorCore,
+        proxy_endpoint: IStreamClientEndpoint,
+        host: bytes,
+        port: int,
+        headers: Headers,
+    ):
         self._reactor = reactor
         self._proxy_endpoint = proxy_endpoint
         self._host = host
         self._port = port
+        self._headers = headers
 
     def __repr__(self):
         return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
 
-    def connect(self, protocolFactory):
-        f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+    def connect(self, protocolFactory: ClientFactory):
+        f = HTTPProxiedClientFactory(
+            self._host, self._port, protocolFactory, self._headers
+        )
         d = self._proxy_endpoint.connect(f)
         # once the tcp socket connects successfully, we need to wait for the
         # CONNECT to complete.
@@ -74,15 +85,23 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
     HTTP Protocol object and run the rest of the connection.
 
     Args:
-        dst_host (bytes): hostname that we want to CONNECT to
-        dst_port (int): port that we want to connect to
-        wrapped_factory (protocol.ClientFactory): The original Factory
+        dst_host: hostname that we want to CONNECT to
+        dst_port: port that we want to connect to
+        wrapped_factory: The original Factory
+        headers: Extra HTTP headers to include in the CONNECT request
     """
 
-    def __init__(self, dst_host, dst_port, wrapped_factory):
+    def __init__(
+        self,
+        dst_host: bytes,
+        dst_port: int,
+        wrapped_factory: ClientFactory,
+        headers: Headers,
+    ):
         self.dst_host = dst_host
         self.dst_port = dst_port
         self.wrapped_factory = wrapped_factory
+        self.headers = headers
         self.on_connection = defer.Deferred()
 
     def startedConnecting(self, connector):
@@ -92,7 +111,11 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
         wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
 
         return HTTPConnectProtocol(
-            self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+            self.dst_host,
+            self.dst_port,
+            wrapped_protocol,
+            self.on_connection,
+            self.headers,
         )
 
     def clientConnectionFailed(self, connector, reason):
@@ -112,24 +135,37 @@ class HTTPConnectProtocol(protocol.Protocol):
     """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
 
     Args:
-        host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+        host: The original HTTP(s) hostname or IPv4 or IPv6 address literal
             to put in the CONNECT request
 
-        port (int): The original HTTP(s) port to put in the CONNECT request
+        port: The original HTTP(s) port to put in the CONNECT request
 
-        wrapped_protocol (interfaces.IProtocol): the original protocol (probably
-            HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+        wrapped_protocol: the original protocol (probably HTTPChannel or
+            TLSMemoryBIOProtocol, but could be anything really)
 
-        connected_deferred (Deferred): a Deferred which will be callbacked with
+        connected_deferred: a Deferred which will be callbacked with
             wrapped_protocol when the CONNECT completes
+
+        headers: Extra HTTP headers to include in the CONNECT request
     """
 
-    def __init__(self, host, port, wrapped_protocol, connected_deferred):
+    def __init__(
+        self,
+        host: bytes,
+        port: int,
+        wrapped_protocol: Protocol,
+        connected_deferred: defer.Deferred,
+        headers: Headers,
+    ):
         self.host = host
         self.port = port
         self.wrapped_protocol = wrapped_protocol
         self.connected_deferred = connected_deferred
-        self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+        self.headers = headers
+
+        self.http_setup_client = HTTPConnectSetupClient(
+            self.host, self.port, self.headers
+        )
         self.http_setup_client.on_connected.addCallback(self.proxyConnected)
 
     def connectionMade(self):
@@ -154,7 +190,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if buf:
             self.wrapped_protocol.dataReceived(buf)
 
-    def dataReceived(self, data):
+    def dataReceived(self, data: bytes):
         # if we've set up the HTTP protocol, we can send the data there
         if self.wrapped_protocol.connected:
             return self.wrapped_protocol.dataReceived(data)
@@ -168,21 +204,29 @@ class HTTPConnectSetupClient(http.HTTPClient):
     """HTTPClient protocol to send a CONNECT message for proxies and read the response.
 
     Args:
-        host (bytes): The hostname to send in the CONNECT message
-        port (int): The port to send in the CONNECT message
+        host: The hostname to send in the CONNECT message
+        port: The port to send in the CONNECT message
+        headers: Extra headers to send with the CONNECT message
     """
 
-    def __init__(self, host, port):
+    def __init__(self, host: bytes, port: int, headers: Headers):
         self.host = host
         self.port = port
+        self.headers = headers
         self.on_connected = defer.Deferred()
 
     def connectionMade(self):
         logger.debug("Connected to proxy, sending CONNECT")
         self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+
+        # Send any additional specified headers
+        for name, values in self.headers.getAllRawHeaders():
+            for value in values:
+                self.sendHeader(name, value)
+
         self.endHeaders()
 
-    def handleStatus(self, version, status, message):
+    def handleStatus(self, version: bytes, status: bytes, message: bytes):
         logger.debug("Got Status: %s %s %s", status, message, version)
         if status != b"200":
             raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 3d553ae236..16ec850064 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -12,10 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import logging
 import re
+from typing import Optional, Tuple
 from urllib.request import getproxies_environment, proxy_bypass_environment
 
+import attr
 from zope.interface import implementer
 
 from twisted.internet import defer
@@ -23,6 +26,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.python.failure import Failure
 from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
 from twisted.web.error import SchemeNotSupported
+from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent
 
 from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
@@ -32,6 +36,22 @@ logger = logging.getLogger(__name__)
 _VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
 
 
+@attr.s
+class ProxyCredentials:
+    username_password = attr.ib(type=bytes)
+
+    def as_proxy_authorization_value(self) -> bytes:
+        """
+        Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').
+
+        Returns:
+            A transformation of the authentication string the encoded value for
+            a Proxy-Authorization header.
+        """
+        # Encode as base64 and prepend the authorization type
+        return b"Basic " + base64.encodebytes(self.username_password)
+
+
 @implementer(IAgent)
 class ProxyAgent(_AgentBase):
     """An Agent implementation which will use an HTTP proxy if one was requested
@@ -96,6 +116,9 @@ class ProxyAgent(_AgentBase):
             https_proxy = proxies["https"].encode() if "https" in proxies else None
             no_proxy = proxies["no"] if "no" in proxies else None
 
+        # Parse credentials from https proxy connection string if present
+        self.https_proxy_creds, https_proxy = parse_username_password(https_proxy)
+
         self.http_proxy_endpoint = _http_proxy_endpoint(
             http_proxy, self.proxy_reactor, **self._endpoint_kwargs
         )
@@ -175,11 +198,22 @@ class ProxyAgent(_AgentBase):
             and self.https_proxy_endpoint
             and not should_skip_proxy
         ):
+            connect_headers = Headers()
+
+            # Determine whether we need to set Proxy-Authorization headers
+            if self.https_proxy_creds:
+                # Set a Proxy-Authorization header
+                connect_headers.addRawHeader(
+                    b"Proxy-Authorization",
+                    self.https_proxy_creds.as_proxy_authorization_value(),
+                )
+
             endpoint = HTTPConnectProxyEndpoint(
                 self.proxy_reactor,
                 self.https_proxy_endpoint,
                 parsed_uri.host,
                 parsed_uri.port,
+                headers=connect_headers,
             )
         else:
             # not using a proxy
@@ -208,12 +242,16 @@ class ProxyAgent(_AgentBase):
         )
 
 
-def _http_proxy_endpoint(proxy, reactor, **kwargs):
+def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs):
     """Parses an http proxy setting and returns an endpoint for the proxy
 
     Args:
-        proxy (bytes|None):  the proxy setting
+        proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>]
+            Note that compared to other apps, this function currently lacks support
+            for specifying a protocol schema (i.e. protocol://...).
+
         reactor: reactor to be used to connect to the proxy
+
         kwargs: other args to be passed to HostnameEndpoint
 
     Returns:
@@ -223,16 +261,43 @@ def _http_proxy_endpoint(proxy, reactor, **kwargs):
     if proxy is None:
         return None
 
-    # currently we only support hostname:port. Some apps also support
-    # protocol://<host>[:port], which allows a way of requiring a TLS connection to the
-    # proxy.
-
+    # Parse the connection string
     host, port = parse_host_port(proxy, default_port=1080)
     return HostnameEndpoint(reactor, host, port, **kwargs)
 
 
-def parse_host_port(hostport, default_port=None):
-    # could have sworn we had one of these somewhere else...
+def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]:
+    """
+    Parses the username and password from a proxy declaration e.g
+    username:password@hostname:port.
+
+    Args:
+        proxy: The proxy connection string.
+
+    Returns
+        An instance of ProxyCredentials and the proxy connection string with any credentials
+        stripped, i.e u:p@host:port -> host:port. If no credentials were found, the
+        ProxyCredentials instance is replaced with None.
+    """
+    if proxy and b"@" in proxy:
+        # We use rsplit here as the password could contain an @ character
+        credentials, proxy_without_credentials = proxy.rsplit(b"@", 1)
+        return ProxyCredentials(credentials), proxy_without_credentials
+
+    return None, proxy
+
+
+def parse_host_port(hostport: bytes, default_port: int = None) -> Tuple[bytes, int]:
+    """
+    Parse the hostname and port from a proxy connection byte string.
+
+    Args:
+        hostport: The proxy connection string. Must be in the form 'host[:port]'.
+        default_port: The default port to return if one is not found in `hostport`.
+
+    Returns:
+        A tuple containing the hostname and port. Uses `default_port` if one was not found.
+    """
     if b":" in hostport:
         host, port = hostport.rsplit(b":", 1)
         try:
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 505ffcd300..3ea8b5bec7 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -12,8 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import logging
 import os
+from typing import Optional
 from unittest.mock import patch
 
 import treq
@@ -242,6 +244,21 @@ class MatrixFederationAgentTests(TestCase):
 
     @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
     def test_https_request_via_proxy(self):
+        """Tests that TLS-encrypted requests can be made through a proxy"""
+        self._do_https_request_via_proxy(auth_credentials=None)
+
+    @patch.dict(
+        os.environ,
+        {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
+    )
+    def test_https_request_via_proxy_with_auth(self):
+        """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
+        self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
+
+    def _do_https_request_via_proxy(
+        self,
+        auth_credentials: Optional[str] = None,
+    ):
         agent = ProxyAgent(
             self.reactor,
             contextFactory=get_test_https_policy(),
@@ -278,6 +295,22 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(request.method, b"CONNECT")
         self.assertEqual(request.path, b"test.com:443")
 
+        # Check whether auth credentials have been supplied to the proxy
+        proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+            b"Proxy-Authorization"
+        )
+
+        if auth_credentials is not None:
+            # Compute the correct header value for Proxy-Authorization
+            encoded_credentials = base64.b64encode(b"bob:pinkponies")
+            expected_header_value = b"Basic " + encoded_credentials
+
+            # Validate the header's value
+            self.assertIn(expected_header_value, proxy_auth_header_values)
+        else:
+            # Check that the Proxy-Authorization header has not been supplied to the proxy
+            self.assertIsNone(proxy_auth_header_values)
+
         # tell the proxy server not to close the connection
         proxy_server.persistent = True
 
@@ -312,6 +345,13 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(request.method, b"GET")
         self.assertEqual(request.path, b"/abc")
         self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+
+        # Check that the destination server DID NOT receive proxy credentials
+        proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+            b"Proxy-Authorization"
+        )
+        self.assertIsNone(proxy_auth_header_values)
+
         request.write(b"result")
         request.finish()