diff --git a/synapse/http/client.py b/synapse/http/client.py
index af34d583ad..1e01e0a9f2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -39,12 +39,15 @@ from zope.interface import implementer, provider
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
+from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
+ ITCPTransport,
)
+from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@@ -56,7 +59,13 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import (
+ UNKNOWN_LENGTH,
+ IAgent,
+ IBodyProducer,
+ IPolicyForHTTPS,
+ IResponse,
+)
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -151,16 +160,17 @@ class _IPBlacklistingResolver:
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
-
- r = recv()
addresses = [] # type: List[IAddress]
def _callback() -> None:
- r.resolutionBegan(None)
-
has_bad_ip = False
- for i in addresses:
- ip_address = IPAddress(i.host)
+ for address in addresses:
+ # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
+ # should go through this path.
+ if not isinstance(address, (IPv4Address, IPv6Address)):
+ continue
+
+ ip_address = IPAddress(address.host)
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
@@ -175,15 +185,15 @@ class _IPBlacklistingResolver:
# request, but all we can really do from here is claim that there were no
# valid results.
if not has_bad_ip:
- for i in addresses:
- r.addressResolved(i)
- r.resolutionComplete()
+ for address in addresses:
+ recv.addressResolved(address)
+ recv.resolutionComplete()
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
- pass
+ recv.resolutionBegan(resolutionInProgress)
@staticmethod
def addressResolved(address: IAddress) -> None:
@@ -197,7 +207,7 @@ class _IPBlacklistingResolver:
EndpointReceiver, hostname, portNumber=portNumber
)
- return r
+ return recv
@implementer(ISynapseReactor)
@@ -346,7 +356,7 @@ class SimpleHttpClient:
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
use_proxy=use_proxy,
- )
+ ) # type: IAgent
if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
@@ -752,6 +762,8 @@ class BodyExceededMaxSize(Exception):
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""
+ transport = None # type: Optional[ITCPTransport]
+
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
@@ -763,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
+ assert self.transport is not None
self.transport.abortConnection()
def dataReceived(self, data: bytes) -> None:
self._maybe_fail()
- def connectionLost(self, reason: Failure) -> None:
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
self._maybe_fail()
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
+ transport = None # type: Optional[ITCPTransport]
+
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@@ -797,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
+ assert self.transport is not None
self.transport.abortConnection()
- def connectionLost(self, reason: Failure) -> None:
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return
@@ -868,6 +884,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
return query_str.encode("utf8")
+@implementer(IPolicyForHTTPS)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
"""
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
|