diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/http/client.py | 101 |
1 files changed, 67 insertions, 34 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py index e5b13593f2..37ccf5ab98 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -32,7 +32,7 @@ from typing import ( import treq from canonicaljson import encode_canonical_json -from netaddr import IPAddress, IPSet +from netaddr import AddrFormatError, IPAddress, IPSet from prometheus_client import Counter from zope.interface import implementer, provider @@ -125,7 +125,7 @@ def _make_scheduler(reactor): return _scheduler -class IPBlacklistingResolver: +class _IPBlacklistingResolver: """ A proxy for reactor.nameResolver which only produces non-blacklisted IP addresses, preventing DNS rebinding attacks on URL preview. @@ -199,6 +199,35 @@ class IPBlacklistingResolver: return r +@implementer(IReactorPluggableNameResolver) +class BlacklistingReactorWrapper: + """ + A Reactor wrapper which will prevent DNS resolution to blacklisted IP + addresses, to prevent DNS rebinding. + """ + + def __init__( + self, + reactor: IReactorPluggableNameResolver, + ip_whitelist: Optional[IPSet], + ip_blacklist: IPSet, + ): + self._reactor = reactor + + # We need to use a DNS resolver which filters out blacklisted IP + # addresses, to prevent DNS rebinding. + self._nameResolver = _IPBlacklistingResolver( + self._reactor, ip_whitelist, ip_blacklist + ) + + def __getattr__(self, attr: str) -> Any: + # Passthrough to the real reactor except for the DNS resolver. + if attr == "nameResolver": + return self._nameResolver + else: + return getattr(self._reactor, attr) + + class BlacklistingAgentWrapper(Agent): """ An Agent wrapper which will prevent access to IP addresses being accessed @@ -232,16 +261,16 @@ class BlacklistingAgentWrapper(Agent): try: ip_address = IPAddress(h.hostname) - + except AddrFormatError: + # Not an IP + pass + else: if check_against_blacklist( ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info("Blocking access to %s due to blacklist" % (ip_address,)) e = SynapseError(403, "IP address blocked by IP blacklist entry") return defer.fail(Failure(e)) - except Exception: - # Not an IP - pass return self._agent.request( method, uri, headers=headers, bodyProducer=bodyProducer @@ -292,22 +321,11 @@ class SimpleHttpClient: self.user_agent = self.user_agent.encode("ascii") if self._ip_blacklist: - real_reactor = hs.get_reactor() # If we have an IP blacklist, we need to use a DNS resolver which # filters out blacklisted IP addresses, to prevent DNS rebinding. - nameResolver = IPBlacklistingResolver( - real_reactor, self._ip_whitelist, self._ip_blacklist + self.reactor = BlacklistingReactorWrapper( + hs.get_reactor(), self._ip_whitelist, self._ip_blacklist ) - - @implementer(IReactorPluggableNameResolver) - class Reactor: - def __getattr__(_self, attr): - if attr == "nameResolver": - return nameResolver - else: - return getattr(real_reactor, attr) - - self.reactor = Reactor() else: self.reactor = hs.get_reactor() @@ -323,6 +341,7 @@ class SimpleHttpClient: self.agent = ProxyAgent( self.reactor, + hs.get_reactor(), connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, @@ -702,11 +721,14 @@ class SimpleHttpClient: try: length = await make_deferred_yieldable( - readBodyToFile(response, output_stream, max_size) + read_body_with_max_size(response, output_stream, max_size) + ) + except BodyExceededMaxSize: + raise SynapseError( + 502, + "Requested file is too large > %r bytes" % (max_size,), + Codes.TOO_LARGE, ) - except SynapseError: - # This can happen e.g. because the body is too large. - raise except Exception as e: raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e @@ -730,7 +752,11 @@ def _timeout_to_request_timed_out_error(f: Failure): return f -class _ReadBodyToFileProtocol(protocol.Protocol): +class BodyExceededMaxSize(Exception): + """The maximum allowed size of the HTTP body was exceeded.""" + + +class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): def __init__( self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] ): @@ -740,20 +766,24 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.max_size = max_size def dataReceived(self, data: bytes) -> None: + # If the deferred was called, bail early. + if self.deferred.called: + return + self.stream.write(data) self.length += len(data) + # The first time the maximum size is exceeded, error and cancel the + # connection. dataReceived might be called again if data was received + # in the meantime. if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback( - SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - ) - ) - self.deferred = defer.Deferred() + self.deferred.errback(BodyExceededMaxSize()) self.transport.loseConnection() def connectionLost(self, reason: Failure) -> None: + # If the maximum size was already exceeded, there's nothing to do. + if self.deferred.called: + return + if reason.check(ResponseDone): self.deferred.callback(self.length) elif reason.check(PotentialDataLoss): @@ -764,12 +794,15 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.deferred.errback(reason) -def readBodyToFile( +def read_body_with_max_size( response: IResponse, stream: BinaryIO, max_size: Optional[int] ) -> defer.Deferred: """ Read a HTTP response body to a file-object. Optionally enforcing a maximum file size. + If the maximum file size is reached, the returned Deferred will resolve to a + Failure with a BodyExceededMaxSize exception. + Args: response: The HTTP response to read from. stream: The file-object to write to. @@ -780,7 +813,7 @@ def readBodyToFile( """ d = defer.Deferred() - response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) + response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size)) return d |