summary refs log tree commit diff
path: root/synapse/http/matrixfederationclient.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/matrixfederationclient.py')
-rw-r--r--synapse/http/matrixfederationclient.py88
1 files changed, 82 insertions, 6 deletions
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index fc5b5ab809..8f4db59c75 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -14,10 +14,11 @@
 # limitations under the License.
 
 
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor, protocol
 from twisted.internet.error import DNSLookupError
 from twisted.web.client import readBody, _AgentBase, _URI
 from twisted.web.http_headers import Headers
+from twisted.web._newclient import ResponseDone
 
 from synapse.http.endpoint import matrix_federation_endpoint
 from synapse.util.async import sleep
@@ -25,7 +26,7 @@ from synapse.util.logcontext import PreserveLoggingContext
 
 from syutil.jsonutil import encode_canonical_json
 
-from synapse.api.errors import CodeMessageException, SynapseError
+from synapse.api.errors import CodeMessageException, SynapseError, Codes
 
 from syutil.crypto.jsonsign import sign_json
 
@@ -244,7 +245,7 @@ class MatrixFederationHttpClient(object):
 
     @defer.inlineCallbacks
     def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
-        """ Get's some json from the given host homeserver and path
+        """ GETs some json from the given host homeserver and path
 
         Args:
             destination (str): The remote server to send the HTTP request
@@ -252,9 +253,6 @@ class MatrixFederationHttpClient(object):
             path (str): The HTTP path.
             args (dict): A dictionary used to create query strings, defaults to
                 None.
-                **Note**: The value of each key is assumed to be an iterable
-                and *not* a string.
-
         Returns:
             Deferred: Succeeds when we get *any* HTTP response.
 
@@ -289,6 +287,52 @@ class MatrixFederationHttpClient(object):
 
         defer.returnValue(json.loads(body))
 
+    @defer.inlineCallbacks
+    def get_file(self, destination, path, output_stream, args={},
+                 retry_on_dns_fail=True, max_size=None):
+        """GETs a file from a given homeserver and writes it to a stream.
+
+        Args:
+            destination (str): The remote server to send the HTTP request to.
+            path (str): The HTTP path to GET.
+            output_stream (file): File to write the response body to.
+            args (dict): Optional dictionary used to create the query string.
+                Each value may be a single string or an iterable of strings.
+                The dictionary is only read, never modified.
+            retry_on_dns_fail (bool): Passed through unchanged to
+                ``_create_request``.
+            max_size (int|None): Size limit handed to ``_readBodyToFile``;
+                None means no limit is applied.
+        Returns:
+            Deferred: resolves to a (int, dict) tuple of the number of body
+            bytes received and a dict built from the raw response headers.
+        Raises:
+            Whatever ``_create_request`` or the body download fails with; a
+            failed download is logged and the error re-raised unchanged.
+        """
+
+        # Normalise every value to a list of UTF-8 byte strings so that a
+        # bare string is not iterated character by character.
+        encoded_args = {}
+        for k, vs in args.items():
+            if isinstance(vs, basestring):
+                vs = [vs]
+            encoded_args[k] = [v.encode("UTF-8") for v in vs]
+
+        # doseq=True: emit one key=value pair per element of each list.
+        query_bytes = urllib.urlencode(encoded_args, True)
+        logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
+
+        def body_callback(method, url_bytes, headers_dict):
+            # A GET carries no body: only sign the request, return no producer.
+            self.sign_request(destination, method, url_bytes, headers_dict)
+            return None
+
+        response = yield self._create_request(
+            destination.encode("ascii"),
+            "GET",
+            path.encode("ascii"),
+            query_bytes=query_bytes,
+            body_callback=body_callback,
+            retry_on_dns_fail=retry_on_dns_fail
+        )
+
+        headers = dict(response.headers.getAllRawHeaders())
+
+        try:
+            length = yield _readBodyToFile(response, output_stream, max_size)
+        except Exception:
+            # Only genuine errors are logged as download failures;
+            # GeneratorExit/KeyboardInterrupt/SystemExit pass straight through.
+            logger.exception("Failed to download body")
+            raise
+
+        defer.returnValue((length, headers))
+
     def _getEndpoint(self, reactor, destination):
         return matrix_federation_endpoint(
             reactor, destination, timeout=10,
@@ -296,6 +340,38 @@ class MatrixFederationHttpClient(object):
         )
 
 
+class _ReadBodyToFileProtocol(protocol.Protocol):
+    """Body protocol that writes a response body to a stream as it arrives.
+
+    ``deferred`` fires with the total number of bytes received once the
+    connection is lost with ``ResponseDone``. It is errbacked with the
+    connection-lost reason otherwise, or with a 502 ``SynapseError`` as soon
+    as the running length reaches ``max_size``.
+    """
+
+    def __init__(self, stream, deferred, max_size):
+        # stream: object with a write() method that receives each chunk.
+        self.stream = stream
+        # deferred: fired exactly once with the outcome of the download.
+        self.deferred = deferred
+        # Running total of body bytes received so far.
+        self.length = 0
+        # max_size: byte limit, or None to disable the check.
+        self.max_size = max_size
+        # Set once the size-limit error has been delivered; every later
+        # event is ignored so the deferred is never fired a second time.
+        self._too_large = False
+
+    def dataReceived(self, data):
+        # Closing the connection is not necessarily immediate, so chunks
+        # may still arrive after the limit tripped; drop them.
+        if self._too_large:
+            return
+
+        self.stream.write(data)
+        self.length += len(data)
+        # NOTE(review): a body of exactly max_size bytes is rejected even
+        # though the message says "> max_size" -- confirm this is intended.
+        if self.max_size is not None and self.length >= self.max_size:
+            # Mark first: errback may run callbacks synchronously.
+            self._too_large = True
+            self.deferred.errback(SynapseError(
+                502,
+                "Requested file is too large > %r bytes" % (self.max_size,),
+                Codes.TOO_LARGE,
+            ))
+            self.transport.loseConnection()
+
+    def connectionLost(self, reason):
+        # The caller has already been given the size-limit error; firing
+        # anything now would only yield an unhandled or duplicate result.
+        if self._too_large:
+            return
+
+        if reason.check(ResponseDone):
+            self.deferred.callback(self.length)
+        else:
+            self.deferred.errback(reason)
+
+
+def _readBodyToFile(response, stream, max_size):
+    """Deliver the body of ``response`` to ``stream``.
+
+    Args:
+        response: Object whose ``deliverBody`` accepts a body protocol.
+        stream: Object the body protocol writes each received chunk to.
+        max_size: Size limit passed to the body protocol.
+    Returns:
+        Deferred: the deferred handed to ``_ReadBodyToFileProtocol``, which
+        that protocol fires with the outcome of the download.
+    """
+    finished = defer.Deferred()
+    body_protocol = _ReadBodyToFileProtocol(stream, finished, max_size)
+    response.deliverBody(body_protocol)
+    return finished
+
+
 def _print_ex(e):
     if hasattr(e, "reasons") and e.reasons:
         for ex in e.reasons: