diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index f05269cdfb..8f4db59c75 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -26,7 +26,7 @@ from synapse.util.logcontext import PreserveLoggingContext
from syutil.jsonutil import encode_canonical_json
-from synapse.api.errors import CodeMessageException, SynapseError
+from synapse.api.errors import CodeMessageException, SynapseError, Codes
from syutil.crypto.jsonsign import sign_json
@@ -289,7 +289,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
- retry_on_dns_fail=True):
+ retry_on_dns_fail=True, max_size=None):
"""GETs a file from a given homeserver
Args:
destination (str): The remote server to send the HTTP request to.
@@ -325,7 +325,11 @@ class MatrixFederationHttpClient(object):
headers = dict(response.headers.getAllRawHeaders())
- length = yield _readBodyToFile(response, output_stream)
+ try:
+ length = yield _readBodyToFile(response, output_stream, max_size)
+ except:
+ logger.exception("Failed to download body")
+ raise
defer.returnValue((length, headers))
@@ -337,14 +341,23 @@ class MatrixFederationHttpClient(object):
class _ReadBodyToFileProtocol(protocol.Protocol):
- def __init__(self, stream, deferred):
+ def __init__(self, stream, deferred, max_size):
self.stream = stream
self.deferred = deferred
self.length = 0
+ self.max_size = max_size
def dataReceived(self, data):
self.stream.write(data)
self.length += len(data)
+ if self.max_size is not None and self.length >= self.max_size:
+ self.deferred.errback(SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (self.max_size,),
+ Codes.TOO_LARGE,
+ ))
+ self.deferred = defer.Deferred()
+ self.transport.loseConnection()
def connectionLost(self, reason):
if reason.check(ResponseDone):
@@ -353,9 +366,9 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
-def _readBodyToFile(response, stream):
+def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
- response.deliverBody(_ReadBodyToFileProtocol(stream, d))
+ response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
|