summary refs log tree commit diff
path: root/synapse/http/matrixfederationclient.py
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2014-12-04 14:22:31 +0000
committerMark Haines <mark.haines@matrix.org>2014-12-04 14:22:31 +0000
commitc01fd5573c92c7c6da258bac7ff377a91cbebfd1 (patch)
tree7562ac0c1e8b8ac8821c3255cbdbcc8d3fc15031 /synapse/http/matrixfederationclient.py
parentFix pyflakes and pep8 warnings (diff)
downloadsynapse-c01fd5573c92c7c6da258bac7ff377a91cbebfd1.tar.xz
Implement download support for media_repository
Diffstat (limited to 'synapse/http/matrixfederationclient.py')
-rw-r--r--synapse/http/matrixfederationclient.py73
1 files changed, 68 insertions, 5 deletions
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 510f07dd7b..c7082b83a7 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -14,10 +14,11 @@
 # limitations under the License.
 
 
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor, protocol
 from twisted.internet.error import DNSLookupError
 from twisted.web.client import readBody, _AgentBase, _URI
 from twisted.web.http_headers import Headers
+from twisted.web._newclient import ResponseDone
 
 from synapse.http.endpoint import matrix_federation_endpoint
 from synapse.util.async import sleep
@@ -227,7 +228,7 @@ class MatrixFederationHttpClient(object):
 
     @defer.inlineCallbacks
     def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
-        """ Get's some json from the given host homeserver and path
+        """ GETs some json from the given host homeserver and path
 
         Args:
             destination (str): The remote server to send the HTTP request
@@ -235,9 +236,6 @@ class MatrixFederationHttpClient(object):
             path (str): The HTTP path.
             args (dict): A dictionary used to create query strings, defaults to
                 None.
-                **Note**: The value of each key is assumed to be an iterable
-                and *not* a string.
-
         Returns:
             Deferred: Succeeds when we get *any* HTTP response.
 
@@ -272,6 +270,48 @@ class MatrixFederationHttpClient(object):
 
         defer.returnValue(json.loads(body))
 
+    @defer.inlineCallbacks
+    def get_file(self, destination, path, output_stream, args={},
+                 retry_on_dns_fail=True):
+        """GETs a file from a given homeserver
+        Args:
+            destination (str): The remote server to send the HTTP request to.
+            path (str): The HTTP path to GET.
+            output_stream (file): File to write the response body to.
+            args (dict): Optional dictionary used to create the query string.
+        Returns:
+            A (int,dict) tuple of the file length and a dict of the response
+            headers.
+        """
+
+        encoded_args = {}
+        for k, vs in args.items():
+            if isinstance(vs, basestring):
+                vs = [vs]
+            encoded_args[k] = [v.encode("UTF-8") for v in vs]
+
+        query_bytes = urllib.urlencode(encoded_args, True)
+        logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
+
+        def body_callback(method, url_bytes, headers_dict):
+            self.sign_request(destination, method, url_bytes, headers_dict)
+            return None
+
+        response = yield self._create_request(
+            destination.encode("ascii"),
+            "GET",
+            path.encode("ascii"),
+            query_bytes=query_bytes,
+            body_callback=body_callback,
+            retry_on_dns_fail=retry_on_dns_fail
+        )
+
+        headers = dict(response.headers.getAllRawHeaders())
+
+        length = yield _readBodyToFile(response, output_stream)
+
+        defer.returnValue((length, headers))
+
     def _getEndpoint(self, reactor, destination):
         return matrix_federation_endpoint(
             reactor, destination, timeout=10,
@@ -279,6 +319,29 @@ class MatrixFederationHttpClient(object):
         )
 
 
+class _ReadBodyToFileProtocol(protocol.Protocol):
+    def __init__(self, stream, deferred):
+        self.stream = stream
+        self.deferred = deferred
+        self.length = 0
+
+    def dataReceived(self, data):
+        self.stream.write(data)
+        self.length += len(data)
+
+    def connectionLost(self, reason):
+        if reason.check(ResponseDone):
+            self.deferred.callback(self.length)
+        else:
+            self.deferred.errback(reason)
+
+
+def _readBodyToFile(response, stream):
+    d = defer.Deferred()
+    response.deliverBody(_ReadBodyToFileProtocol(stream, d))
+    return d
+
+
 def _print_ex(e):
     if hasattr(e, "reasons") and e.reasons:
         for ex in e.reasons: