summary refs log tree commit diff
path: root/synapse/http/connectproxyclient.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/connectproxyclient.py')
-rw-r--r--synapse/http/connectproxyclient.py39
1 files changed, 23 insertions, 16 deletions
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 203e995bb7..23a60af171 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -14,15 +14,22 @@
 
 import base64
 import logging
-from typing import Optional
+from typing import Optional, Union
 
 import attr
 from zope.interface import implementer
 
 from twisted.internet import defer, protocol
 from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+    IAddress,
+    IConnector,
+    IProtocol,
+    IReactorCore,
+    IStreamClientEndpoint,
+)
 from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
+from twisted.python.failure import Failure
 from twisted.web import http
 
 logger = logging.getLogger(__name__)
@@ -81,14 +88,14 @@ class HTTPConnectProxyEndpoint:
         self._port = port
         self._proxy_creds = proxy_creds
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
 
     # Mypy encounters a false positive here: it complains that ClientFactory
     # is incompatible with IProtocolFactory. But ClientFactory inherits from
     # Factory, which implements IProtocolFactory. So I think this is a bug
     # in mypy-zope.
-    def connect(self, protocolFactory: ClientFactory):  # type: ignore[override]
+    def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]":  # type: ignore[override]
         f = HTTPProxiedClientFactory(
             self._host, self._port, protocolFactory, self._proxy_creds
         )
@@ -125,10 +132,10 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
         self.proxy_creds = proxy_creds
         self.on_connection: "defer.Deferred[None]" = defer.Deferred()
 
-    def startedConnecting(self, connector):
+    def startedConnecting(self, connector: IConnector) -> None:
         return self.wrapped_factory.startedConnecting(connector)
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol":
         wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
         if wrapped_protocol is None:
             raise TypeError("buildProtocol produced None instead of a Protocol")
@@ -141,13 +148,13 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
             self.proxy_creds,
         )
 
-    def clientConnectionFailed(self, connector, reason):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy failed: %s", reason)
         if not self.on_connection.called:
             self.on_connection.errback(reason)
         return self.wrapped_factory.clientConnectionFailed(connector, reason)
 
-    def clientConnectionLost(self, connector, reason):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy lost: %s", reason)
         if not self.on_connection.called:
             self.on_connection.errback(reason)
@@ -191,10 +198,10 @@ class HTTPConnectProtocol(protocol.Protocol):
         )
         self.http_setup_client.on_connected.addCallback(self.proxyConnected)
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         self.http_setup_client.makeConnection(self.transport)
 
-    def connectionLost(self, reason=connectionDone):
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         if self.wrapped_protocol.connected:
             self.wrapped_protocol.connectionLost(reason)
 
@@ -203,7 +210,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if not self.connected_deferred.called:
             self.connected_deferred.errback(reason)
 
-    def proxyConnected(self, _):
+    def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None:
         self.wrapped_protocol.makeConnection(self.transport)
 
         self.connected_deferred.callback(self.wrapped_protocol)
@@ -213,7 +220,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if buf:
             self.wrapped_protocol.dataReceived(buf)
 
-    def dataReceived(self, data: bytes):
+    def dataReceived(self, data: bytes) -> None:
         # if we've set up the HTTP protocol, we can send the data there
         if self.wrapped_protocol.connected:
             return self.wrapped_protocol.dataReceived(data)
@@ -243,7 +250,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
         self.proxy_creds = proxy_creds
         self.on_connected: "defer.Deferred[None]" = defer.Deferred()
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         logger.debug("Connected to proxy, sending CONNECT")
         self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
 
@@ -257,14 +264,14 @@ class HTTPConnectSetupClient(http.HTTPClient):
 
         self.endHeaders()
 
-    def handleStatus(self, version: bytes, status: bytes, message: bytes):
+    def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None:
         logger.debug("Got Status: %s %s %s", status, message, version)
         if status != b"200":
             raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
 
-    def handleEndHeaders(self):
+    def handleEndHeaders(self) -> None:
         logger.debug("End Headers")
         self.on_connected.callback(None)
 
-    def handleResponse(self, body):
+    def handleResponse(self, body: bytes) -> None:
         pass