summary refs log tree commit diff
path: root/synapse/http/connectproxyclient.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/connectproxyclient.py')
-rw-r--r--synapse/http/connectproxyclient.py96
1 files changed, 70 insertions, 26 deletions
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 856e28454f..b797e3ce80 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -19,9 +19,10 @@ from zope.interface import implementer
 
 from twisted.internet import defer, protocol
 from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.internet.protocol import connectionDone
+from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
 from twisted.web import http
+from twisted.web.http_headers import Headers
 
 logger = logging.getLogger(__name__)
 
@@ -43,23 +44,33 @@ class HTTPConnectProxyEndpoint:
 
     Args:
         reactor: the Twisted reactor to use for the connection
-        proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
-            proxy
-        host (bytes): hostname that we want to CONNECT to
-        port (int): port that we want to connect to
+        proxy_endpoint: the endpoint to use to connect to the proxy
+        host: hostname that we want to CONNECT to
+        port: port that we want to connect to
+        headers: Extra HTTP headers to include in the CONNECT request
     """
 
-    def __init__(self, reactor, proxy_endpoint, host, port):
+    def __init__(
+        self,
+        reactor: IReactorCore,
+        proxy_endpoint: IStreamClientEndpoint,
+        host: bytes,
+        port: int,
+        headers: Headers,
+    ):
         self._reactor = reactor
         self._proxy_endpoint = proxy_endpoint
         self._host = host
         self._port = port
+        self._headers = headers
 
     def __repr__(self):
         return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
 
-    def connect(self, protocolFactory):
-        f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+    def connect(self, protocolFactory: ClientFactory):
+        f = HTTPProxiedClientFactory(
+            self._host, self._port, protocolFactory, self._headers
+        )
         d = self._proxy_endpoint.connect(f)
         # once the tcp socket connects successfully, we need to wait for the
         # CONNECT to complete.
@@ -74,15 +85,23 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
     HTTP Protocol object and run the rest of the connection.
 
     Args:
-        dst_host (bytes): hostname that we want to CONNECT to
-        dst_port (int): port that we want to connect to
-        wrapped_factory (protocol.ClientFactory): The original Factory
+        dst_host: hostname that we want to CONNECT to
+        dst_port: port that we want to connect to
+        wrapped_factory: The original Factory
+        headers: Extra HTTP headers to include in the CONNECT request
     """
 
-    def __init__(self, dst_host, dst_port, wrapped_factory):
+    def __init__(
+        self,
+        dst_host: bytes,
+        dst_port: int,
+        wrapped_factory: ClientFactory,
+        headers: Headers,
+    ):
         self.dst_host = dst_host
         self.dst_port = dst_port
         self.wrapped_factory = wrapped_factory
+        self.headers = headers
         self.on_connection = defer.Deferred()
 
     def startedConnecting(self, connector):
@@ -92,7 +111,11 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
         wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
 
         return HTTPConnectProtocol(
-            self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+            self.dst_host,
+            self.dst_port,
+            wrapped_protocol,
+            self.on_connection,
+            self.headers,
         )
 
     def clientConnectionFailed(self, connector, reason):
@@ -112,24 +135,37 @@ class HTTPConnectProtocol(protocol.Protocol):
     """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
 
     Args:
-        host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+        host: The original HTTP(s) hostname or IPv4 or IPv6 address literal
             to put in the CONNECT request
 
-        port (int): The original HTTP(s) port to put in the CONNECT request
+        port: The original HTTP(s) port to put in the CONNECT request
 
-        wrapped_protocol (interfaces.IProtocol): the original protocol (probably
-            HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+        wrapped_protocol: the original protocol (probably HTTPChannel or
+            TLSMemoryBIOProtocol, but could be anything really)
 
-        connected_deferred (Deferred): a Deferred which will be callbacked with
+        connected_deferred: a Deferred which will be callbacked with
             wrapped_protocol when the CONNECT completes
+
+        headers: Extra HTTP headers to include in the CONNECT request
     """
 
-    def __init__(self, host, port, wrapped_protocol, connected_deferred):
+    def __init__(
+        self,
+        host: bytes,
+        port: int,
+        wrapped_protocol: Protocol,
+        connected_deferred: defer.Deferred,
+        headers: Headers,
+    ):
         self.host = host
         self.port = port
         self.wrapped_protocol = wrapped_protocol
         self.connected_deferred = connected_deferred
-        self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+        self.headers = headers
+
+        self.http_setup_client = HTTPConnectSetupClient(
+            self.host, self.port, self.headers
+        )
         self.http_setup_client.on_connected.addCallback(self.proxyConnected)
 
     def connectionMade(self):
@@ -154,7 +190,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if buf:
             self.wrapped_protocol.dataReceived(buf)
 
-    def dataReceived(self, data):
+    def dataReceived(self, data: bytes):
         # if we've set up the HTTP protocol, we can send the data there
         if self.wrapped_protocol.connected:
             return self.wrapped_protocol.dataReceived(data)
@@ -168,21 +204,29 @@ class HTTPConnectSetupClient(http.HTTPClient):
     """HTTPClient protocol to send a CONNECT message for proxies and read the response.
 
     Args:
-        host (bytes): The hostname to send in the CONNECT message
-        port (int): The port to send in the CONNECT message
+        host: The hostname to send in the CONNECT message
+        port: The port to send in the CONNECT message
+        headers: Extra headers to send with the CONNECT message
     """
 
-    def __init__(self, host, port):
+    def __init__(self, host: bytes, port: int, headers: Headers):
         self.host = host
         self.port = port
+        self.headers = headers
         self.on_connected = defer.Deferred()
 
     def connectionMade(self):
         logger.debug("Connected to proxy, sending CONNECT")
         self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+
+        # Send any additional specified headers
+        for name, values in self.headers.getAllRawHeaders():
+            for value in values:
+                self.sendHeader(name, value)
+
         self.endHeaders()
 
-    def handleStatus(self, version, status, message):
+    def handleStatus(self, version: bytes, status: bytes, message: bytes):
         logger.debug("Got Status: %s %s %s", status, message, version)
         if status != b"200":
             raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)