summary refs log tree commit diff
path: root/synapse/http/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/client.py')
-rw-r--r--synapse/http/client.py127
1 files changed, 89 insertions, 38 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 9eba046bbf..4abb479ae3 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -18,7 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
 from synapse.api.errors import (
     CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
 )
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logcontext import make_deferred_yieldable
 from synapse.util import logcontext
 import synapse.metrics
 from synapse.http.endpoint import SpiderEndpoint
@@ -114,43 +114,73 @@ class SimpleHttpClient(object):
             raise e
 
     @defer.inlineCallbacks
-    def post_urlencoded_get_json(self, uri, args={}):
+    def post_urlencoded_get_json(self, uri, args={}, headers=None):
+        """
+        Args:
+            uri (str):
+            args (dict[str, str|List[str]]): query params
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
+
+        Returns:
+            Deferred[object]: parsed json
+        """
+
         # TODO: Do we ever want to log message contents?
         logger.debug("post_urlencoded_get_json args: %s", args)
 
         query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
 
+        actual_headers = {
+            b"Content-Type": [b"application/x-www-form-urlencoded"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "POST",
             uri.encode("ascii"),
-            headers=Headers({
-                b"Content-Type": [b"application/x-www-form-urlencoded"],
-                b"User-Agent": [self.user_agent],
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(query_bytes))
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
 
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
-    def post_json_get_json(self, uri, post_json):
+    def post_json_get_json(self, uri, post_json, headers=None):
+        """
+
+        Args:
+            uri (str):
+            post_json (object):
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
+
+        Returns:
+            Deferred[object]: parsed json
+        """
         json_str = encode_canonical_json(post_json)
 
         logger.debug("HTTP POST %s -> %s", json_str, uri)
 
+        actual_headers = {
+            b"Content-Type": [b"application/json"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "POST",
             uri.encode("ascii"),
-            headers=Headers({
-                b"Content-Type": [b"application/json"],
-                b"User-Agent": [self.user_agent],
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(json_str))
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
             defer.returnValue(json.loads(body))
@@ -160,7 +190,7 @@ class SimpleHttpClient(object):
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
-    def get_json(self, uri, args={}):
+    def get_json(self, uri, args={}, headers=None):
         """ Gets some json from the given URI.
 
         Args:
@@ -169,6 +199,8 @@ class SimpleHttpClient(object):
                 None.
                 **Note**: The value of each key is assumed to be an iterable
                 and *not* a string.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             Deferred: Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body as JSON.
@@ -177,13 +209,13 @@ class SimpleHttpClient(object):
             error message.
         """
         try:
-            body = yield self.get_raw(uri, args)
+            body = yield self.get_raw(uri, args, headers=headers)
             defer.returnValue(json.loads(body))
         except CodeMessageException as e:
             raise self._exceptionFromFailedRequest(e.code, e.msg)
 
     @defer.inlineCallbacks
-    def put_json(self, uri, json_body, args={}):
+    def put_json(self, uri, json_body, args={}, headers=None):
         """ Puts some json to the given URI.
 
         Args:
@@ -193,6 +225,8 @@ class SimpleHttpClient(object):
                 None.
                 **Note**: The value of each key is assumed to be an iterable
                 and *not* a string.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             Deferred: Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body as JSON.
@@ -205,17 +239,21 @@ class SimpleHttpClient(object):
 
         json_str = encode_canonical_json(json_body)
 
+        actual_headers = {
+            b"Content-Type": [b"application/json"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "PUT",
             uri.encode("ascii"),
-            headers=Headers({
-                b"User-Agent": [self.user_agent],
-                "Content-Type": ["application/json"]
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(json_str))
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
             defer.returnValue(json.loads(body))
@@ -226,7 +264,7 @@ class SimpleHttpClient(object):
             raise CodeMessageException(response.code, body)
 
     @defer.inlineCallbacks
-    def get_raw(self, uri, args={}):
+    def get_raw(self, uri, args={}, headers=None):
         """ Gets raw text from the given URI.
 
         Args:
@@ -235,6 +273,8 @@ class SimpleHttpClient(object):
                 None.
                 **Note**: The value of each key is assumed to be an iterable
                 and *not* a string.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             Deferred: Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body at text.
@@ -246,15 +286,19 @@ class SimpleHttpClient(object):
             query_bytes = urllib.urlencode(args, True)
             uri = "%s?%s" % (uri, query_bytes)
 
+        actual_headers = {
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "GET",
             uri.encode("ascii"),
-            headers=Headers({
-                b"User-Agent": [self.user_agent],
-            })
+            headers=Headers(actual_headers),
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
             defer.returnValue(body)
@@ -274,27 +318,33 @@ class SimpleHttpClient(object):
     # The two should be factored out.
 
     @defer.inlineCallbacks
-    def get_file(self, url, output_stream, max_size=None):
+    def get_file(self, url, output_stream, max_size=None, headers=None):
         """GETs a file from a given URL
         Args:
             url (str): The URL to GET
             output_stream (file): File to write the response body to.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             A (int,dict,string,int) tuple of the file length, dict of the response
             headers, absolute URI of the response and HTTP response code.
         """
 
+        actual_headers = {
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "GET",
             url.encode("ascii"),
-            headers=Headers({
-                b"User-Agent": [self.user_agent],
-            })
+            headers=Headers(actual_headers),
         )
 
-        headers = dict(response.headers.getAllRawHeaders())
+        resp_headers = dict(response.headers.getAllRawHeaders())
 
-        if 'Content-Length' in headers and headers['Content-Length'] > max_size:
+        if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
             logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
             raise SynapseError(
                 502,
@@ -315,10 +365,9 @@ class SimpleHttpClient(object):
         # straight back in again
 
         try:
-            length = yield preserve_context_over_fn(
-                _readBodyToFile,
-                response, output_stream, max_size
-            )
+            length = yield make_deferred_yieldable(_readBodyToFile(
+                response, output_stream, max_size,
+            ))
         except Exception as e:
             logger.exception("Failed to download body")
             raise SynapseError(
@@ -327,7 +376,9 @@ class SimpleHttpClient(object):
                 Codes.UNKNOWN,
             )
 
-        defer.returnValue((length, headers, response.request.absoluteURI, response.code))
+        defer.returnValue(
+            (length, resp_headers, response.request.absoluteURI, response.code),
+        )
 
 
 # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
@@ -395,7 +446,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
         )
 
         try:
-            body = yield preserve_context_over_fn(readBody, response)
+            body = yield make_deferred_yieldable(readBody(response))
             defer.returnValue(body)
         except PartialDownloadError as e:
             # twisted dislikes google's response, no content length.