summary refs log tree commit diff
path: root/synapse/http/connectproxyclient.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/connectproxyclient.py')
-rw-r--r--synapse/http/connectproxyclient.py68
1 files changed, 47 insertions, 21 deletions
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 17e1c5abb1..c577142268 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import base64
 import logging
+from typing import Optional
 
+import attr
 from zope.interface import implementer
 
 from twisted.internet import defer, protocol
@@ -21,7 +24,6 @@ from twisted.internet.error import ConnectError
 from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
 from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
 from twisted.web import http
-from twisted.web.http_headers import Headers
 
 logger = logging.getLogger(__name__)
 
@@ -30,6 +32,22 @@ class ProxyConnectError(ConnectError):
     pass
 
 
+@attr.s
+class ProxyCredentials:
+    username_password = attr.ib(type=bytes)
+
+    def as_proxy_authorization_value(self) -> bytes:
+        """
+        Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').
+
+        Returns:
+            A transformation of the authentication string the encoded value for
+            a Proxy-Authorization header.
+        """
+        # Encode as base64 and prepend the authorization type
+        return b"Basic " + base64.encodebytes(self.username_password)
+
+
 @implementer(IStreamClientEndpoint)
 class HTTPConnectProxyEndpoint:
     """An Endpoint implementation which will send a CONNECT request to an http proxy
@@ -46,7 +64,7 @@ class HTTPConnectProxyEndpoint:
         proxy_endpoint: the endpoint to use to connect to the proxy
         host: hostname that we want to CONNECT to
         port: port that we want to connect to
-        headers: Extra HTTP headers to include in the CONNECT request
+        proxy_creds: credentials to authenticate at proxy
     """
 
     def __init__(
@@ -55,20 +73,20 @@ class HTTPConnectProxyEndpoint:
         proxy_endpoint: IStreamClientEndpoint,
         host: bytes,
         port: int,
-        headers: Headers,
+        proxy_creds: Optional[ProxyCredentials],
     ):
         self._reactor = reactor
         self._proxy_endpoint = proxy_endpoint
         self._host = host
         self._port = port
-        self._headers = headers
+        self._proxy_creds = proxy_creds
 
     def __repr__(self):
         return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
 
     def connect(self, protocolFactory: ClientFactory):
         f = HTTPProxiedClientFactory(
-            self._host, self._port, protocolFactory, self._headers
+            self._host, self._port, protocolFactory, self._proxy_creds
         )
         d = self._proxy_endpoint.connect(f)
         # once the tcp socket connects successfully, we need to wait for the
@@ -87,7 +105,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
         dst_host: hostname that we want to CONNECT to
         dst_port: port that we want to connect to
         wrapped_factory: The original Factory
-        headers: Extra HTTP headers to include in the CONNECT request
+        proxy_creds: credentials to authenticate at proxy
     """
 
     def __init__(
@@ -95,12 +113,12 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
         dst_host: bytes,
         dst_port: int,
         wrapped_factory: ClientFactory,
-        headers: Headers,
+        proxy_creds: Optional[ProxyCredentials],
     ):
         self.dst_host = dst_host
         self.dst_port = dst_port
         self.wrapped_factory = wrapped_factory
-        self.headers = headers
+        self.proxy_creds = proxy_creds
         self.on_connection = defer.Deferred()
 
     def startedConnecting(self, connector):
@@ -114,7 +132,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
             self.dst_port,
             wrapped_protocol,
             self.on_connection,
-            self.headers,
+            self.proxy_creds,
         )
 
     def clientConnectionFailed(self, connector, reason):
@@ -145,7 +163,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         connected_deferred: a Deferred which will be callbacked with
             wrapped_protocol when the CONNECT completes
 
-        headers: Extra HTTP headers to include in the CONNECT request
+        proxy_creds: credentials to authenticate at proxy
     """
 
     def __init__(
@@ -154,16 +172,16 @@ class HTTPConnectProtocol(protocol.Protocol):
         port: int,
         wrapped_protocol: Protocol,
         connected_deferred: defer.Deferred,
-        headers: Headers,
+        proxy_creds: Optional[ProxyCredentials],
     ):
         self.host = host
         self.port = port
         self.wrapped_protocol = wrapped_protocol
         self.connected_deferred = connected_deferred
-        self.headers = headers
+        self.proxy_creds = proxy_creds
 
         self.http_setup_client = HTTPConnectSetupClient(
-            self.host, self.port, self.headers
+            self.host, self.port, self.proxy_creds
         )
         self.http_setup_client.on_connected.addCallback(self.proxyConnected)
 
@@ -205,30 +223,38 @@ class HTTPConnectSetupClient(http.HTTPClient):
     Args:
         host: The hostname to send in the CONNECT message
         port: The port to send in the CONNECT message
-        headers: Extra headers to send with the CONNECT message
+        proxy_creds: credentials to authenticate at proxy
     """
 
-    def __init__(self, host: bytes, port: int, headers: Headers):
+    def __init__(
+        self,
+        host: bytes,
+        port: int,
+        proxy_creds: Optional[ProxyCredentials],
+    ):
         self.host = host
         self.port = port
-        self.headers = headers
+        self.proxy_creds = proxy_creds
         self.on_connected = defer.Deferred()
 
     def connectionMade(self):
         logger.debug("Connected to proxy, sending CONNECT")
         self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
 
-        # Send any additional specified headers
-        for name, values in self.headers.getAllRawHeaders():
-            for value in values:
-                self.sendHeader(name, value)
+        # Determine whether we need to set Proxy-Authorization headers
+        if self.proxy_creds:
+            # Set a Proxy-Authorization header
+            self.sendHeader(
+                b"Proxy-Authorization",
+                self.proxy_creds.as_proxy_authorization_value(),
+            )
 
         self.endHeaders()
 
     def handleStatus(self, version: bytes, status: bytes, message: bytes):
         logger.debug("Got Status: %s %s %s", status, message, version)
         if status != b"200":
-            raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
+            raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
 
     def handleEndHeaders(self):
         logger.debug("End Headers")