diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 856e28454f..b797e3ce80 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -19,9 +19,10 @@ from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.internet.protocol import connectionDone
+from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
from twisted.web import http
+from twisted.web.http_headers import Headers
logger = logging.getLogger(__name__)
@@ -43,23 +44,33 @@ class HTTPConnectProxyEndpoint:
Args:
reactor: the Twisted reactor to use for the connection
- proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
- proxy
- host (bytes): hostname that we want to CONNECT to
- port (int): port that we want to connect to
+ proxy_endpoint: the endpoint to use to connect to the proxy
+ host: hostname that we want to CONNECT to
+ port: port that we want to connect to
+ headers: Extra HTTP headers to include in the CONNECT request
"""
- def __init__(self, reactor, proxy_endpoint, host, port):
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ proxy_endpoint: IStreamClientEndpoint,
+ host: bytes,
+ port: int,
+ headers: Headers,
+ ):
self._reactor = reactor
self._proxy_endpoint = proxy_endpoint
self._host = host
self._port = port
+ self._headers = headers
def __repr__(self):
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
- def connect(self, protocolFactory):
- f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+ def connect(self, protocolFactory: ClientFactory):
+ f = HTTPProxiedClientFactory(
+ self._host, self._port, protocolFactory, self._headers
+ )
d = self._proxy_endpoint.connect(f)
# once the tcp socket connects successfully, we need to wait for the
# CONNECT to complete.
@@ -74,15 +85,23 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
HTTP Protocol object and run the rest of the connection.
Args:
- dst_host (bytes): hostname that we want to CONNECT to
- dst_port (int): port that we want to connect to
- wrapped_factory (protocol.ClientFactory): The original Factory
+ dst_host: hostname that we want to CONNECT to
+ dst_port: port that we want to connect to
+ wrapped_factory: The original Factory
+ headers: Extra HTTP headers to include in the CONNECT request
"""
- def __init__(self, dst_host, dst_port, wrapped_factory):
+ def __init__(
+ self,
+ dst_host: bytes,
+ dst_port: int,
+ wrapped_factory: ClientFactory,
+ headers: Headers,
+ ):
self.dst_host = dst_host
self.dst_port = dst_port
self.wrapped_factory = wrapped_factory
+ self.headers = headers
self.on_connection = defer.Deferred()
def startedConnecting(self, connector):
@@ -92,7 +111,11 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
return HTTPConnectProtocol(
- self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+ self.dst_host,
+ self.dst_port,
+ wrapped_protocol,
+ self.on_connection,
+ self.headers,
)
def clientConnectionFailed(self, connector, reason):
@@ -112,24 +135,37 @@ class HTTPConnectProtocol(protocol.Protocol):
"""Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
Args:
- host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+ host: The original HTTP(s) hostname or IPv4 or IPv6 address literal
to put in the CONNECT request
- port (int): The original HTTP(s) port to put in the CONNECT request
+ port: The original HTTP(s) port to put in the CONNECT request
- wrapped_protocol (interfaces.IProtocol): the original protocol (probably
- HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+ wrapped_protocol: the original protocol (probably HTTPChannel or
+ TLSMemoryBIOProtocol, but could be anything really)
- connected_deferred (Deferred): a Deferred which will be callbacked with
+ connected_deferred: a Deferred which will be callbacked with
wrapped_protocol when the CONNECT completes
+
+ headers: Extra HTTP headers to include in the CONNECT request
"""
- def __init__(self, host, port, wrapped_protocol, connected_deferred):
+ def __init__(
+ self,
+ host: bytes,
+ port: int,
+ wrapped_protocol: Protocol,
+ connected_deferred: defer.Deferred,
+ headers: Headers,
+ ):
self.host = host
self.port = port
self.wrapped_protocol = wrapped_protocol
self.connected_deferred = connected_deferred
- self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+ self.headers = headers
+
+ self.http_setup_client = HTTPConnectSetupClient(
+ self.host, self.port, self.headers
+ )
self.http_setup_client.on_connected.addCallback(self.proxyConnected)
def connectionMade(self):
@@ -154,7 +190,7 @@ class HTTPConnectProtocol(protocol.Protocol):
if buf:
self.wrapped_protocol.dataReceived(buf)
- def dataReceived(self, data):
+ def dataReceived(self, data: bytes):
# if we've set up the HTTP protocol, we can send the data there
if self.wrapped_protocol.connected:
return self.wrapped_protocol.dataReceived(data)
@@ -168,21 +204,29 @@ class HTTPConnectSetupClient(http.HTTPClient):
"""HTTPClient protocol to send a CONNECT message for proxies and read the response.
Args:
- host (bytes): The hostname to send in the CONNECT message
- port (int): The port to send in the CONNECT message
+ host: The hostname to send in the CONNECT message
+ port: The port to send in the CONNECT message
+ headers: Extra headers to send with the CONNECT message
"""
- def __init__(self, host, port):
+ def __init__(self, host: bytes, port: int, headers: Headers):
self.host = host
self.port = port
+ self.headers = headers
self.on_connected = defer.Deferred()
def connectionMade(self):
logger.debug("Connected to proxy, sending CONNECT")
self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+
+ # Send any additional specified headers
+ for name, values in self.headers.getAllRawHeaders():
+ for value in values:
+ self.sendHeader(name, value)
+
self.endHeaders()
- def handleStatus(self, version, status, message):
+ def handleStatus(self, version: bytes, status: bytes, message: bytes):
logger.debug("Got Status: %s %s %s", status, message, version)
if status != b"200":
raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
|