diff --git a/synapse/http/client.py b/synapse/http/client.py
index e5b13593f2..37ccf5ab98 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -32,7 +32,7 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
-from netaddr import IPAddress, IPSet
+from netaddr import AddrFormatError, IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -125,7 +125,7 @@ def _make_scheduler(reactor):
return _scheduler
-class IPBlacklistingResolver:
+class _IPBlacklistingResolver:
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
addresses, preventing DNS rebinding attacks on URL preview.
@@ -199,6 +199,35 @@ class IPBlacklistingResolver:
return r
+@implementer(IReactorPluggableNameResolver)
+class BlacklistingReactorWrapper:
+ """
+ A Reactor wrapper which will prevent DNS resolution to blacklisted IP
+ addresses, to prevent DNS rebinding.
+ """
+
+ def __init__(
+ self,
+ reactor: IReactorPluggableNameResolver,
+ ip_whitelist: Optional[IPSet],
+ ip_blacklist: IPSet,
+ ):
+ self._reactor = reactor
+
+ # We need to use a DNS resolver which filters out blacklisted IP
+ # addresses, to prevent DNS rebinding.
+ self._nameResolver = _IPBlacklistingResolver(
+ self._reactor, ip_whitelist, ip_blacklist
+ )
+
+ def __getattr__(self, attr: str) -> Any:
+ # Passthrough to the real reactor except for the DNS resolver.
+ if attr == "nameResolver":
+ return self._nameResolver
+ else:
+ return getattr(self._reactor, attr)
+
+
class BlacklistingAgentWrapper(Agent):
"""
An Agent wrapper which will prevent access to IP addresses being accessed
@@ -232,16 +261,16 @@ class BlacklistingAgentWrapper(Agent):
try:
ip_address = IPAddress(h.hostname)
-
+ except AddrFormatError:
+ # Not an IP
+ pass
+ else:
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info("Blocking access to %s due to blacklist" % (ip_address,))
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
- except Exception:
- # Not an IP
- pass
return self._agent.request(
method, uri, headers=headers, bodyProducer=bodyProducer
@@ -292,22 +321,11 @@ class SimpleHttpClient:
self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
- real_reactor = hs.get_reactor()
# If we have an IP blacklist, we need to use a DNS resolver which
# filters out blacklisted IP addresses, to prevent DNS rebinding.
- nameResolver = IPBlacklistingResolver(
- real_reactor, self._ip_whitelist, self._ip_blacklist
+ self.reactor = BlacklistingReactorWrapper(
+ hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
)
-
- @implementer(IReactorPluggableNameResolver)
- class Reactor:
- def __getattr__(_self, attr):
- if attr == "nameResolver":
- return nameResolver
- else:
- return getattr(real_reactor, attr)
-
- self.reactor = Reactor()
else:
self.reactor = hs.get_reactor()
@@ -323,6 +341,7 @@ class SimpleHttpClient:
self.agent = ProxyAgent(
self.reactor,
+ hs.get_reactor(),
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
@@ -702,11 +721,14 @@ class SimpleHttpClient:
try:
length = await make_deferred_yieldable(
- readBodyToFile(response, output_stream, max_size)
+ read_body_with_max_size(response, output_stream, max_size)
+ )
+ except BodyExceededMaxSize:
+ raise SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (max_size,),
+ Codes.TOO_LARGE,
)
- except SynapseError:
- # This can happen e.g. because the body is too large.
- raise
except Exception as e:
raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
@@ -730,7 +752,11 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
-class _ReadBodyToFileProtocol(protocol.Protocol):
+class BodyExceededMaxSize(Exception):
+ """The maximum allowed size of the HTTP body was exceeded."""
+
+
+class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@@ -740,20 +766,24 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.max_size = max_size
def dataReceived(self, data: bytes) -> None:
+ # If the deferred was called, bail early.
+ if self.deferred.called:
+ return
+
self.stream.write(data)
self.length += len(data)
+ # The first time the maximum size is exceeded, error and cancel the
+ # connection. dataReceived might be called again if data was received
+ # in the meantime.
if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(
- SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- )
- )
- self.deferred = defer.Deferred()
+ self.deferred.errback(BodyExceededMaxSize())
self.transport.loseConnection()
def connectionLost(self, reason: Failure) -> None:
+ # If the maximum size was already exceeded, there's nothing to do.
+ if self.deferred.called:
+ return
+
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
@@ -764,12 +794,15 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
-def readBodyToFile(
+def read_body_with_max_size(
response: IResponse, stream: BinaryIO, max_size: Optional[int]
) -> defer.Deferred:
"""
Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
+ If the maximum file size is reached, the returned Deferred will resolve to a
+ Failure with a BodyExceededMaxSize exception.
+
Args:
response: The HTTP response to read from.
stream: The file-object to write to.
@@ -780,7 +813,7 @@ def readBodyToFile(
"""
d = defer.Deferred()
- response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
+ response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
return d
|