-rw-r--r--   changelog.d/8372.misc                            1
-rw-r--r--   synapse/appservice/api.py                        2
-rw-r--r--   synapse/http/client.py                         187
-rw-r--r--   synapse/rest/media/v1/preview_url_resource.py   14
4 files changed, 143 insertions, 61 deletions
diff --git a/changelog.d/8372.misc b/changelog.d/8372.misc
new file mode 100644
index 0000000000..a56e36de4b
--- /dev/null
+++ b/changelog.d/8372.misc
@@ -0,0 +1 @@
+Add type annotations to `SimpleHttpClient`.
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 1514c0f691..c526c28b93 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                 urllib.parse.quote(protocol),
             )
             try:
-                info = await self.get_json(uri, {})
+                info = await self.get_json(uri)
 
                 if not _is_valid_3pe_metadata(info):
                     logger.warning(
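For illustration, a minimal sketch (not part of the patch) of how a call site changes now that `get_json` defaults `args` to an empty mapping; `client` and `fetch_protocol_metadata` are hypothetical stand-ins:

```python
# Hypothetical call site, assuming `client` is a configured SimpleHttpClient
# and `uri` is the third-party protocol metadata URL built above.
async def fetch_protocol_metadata(client, uri: str):
    # Previously: info = await client.get_json(uri, {})
    # With `args` defaulting to {}, the empty dict can simply be dropped:
    info = await client.get_json(uri)
    return info
```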
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 13fcab3378..4694adc400 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -17,6 +17,18 @@
 import logging
 import urllib
 from io import BytesIO
+from typing import (
+    Any,
+    BinaryIO,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import treq
 from canonicaljson import encode_canonical_json
@@ -37,6 +49,7 @@ from twisted.web._newclient import ResponseDone
 from twisted.web.client import Agent, HTTPConnectionPool, readBody
 from twisted.web.http import PotentialDataLoss
 from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
 
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.http import (
@@ -57,6 +70,19 @@ incoming_responses_counter = Counter(
     "synapse_http_client_responses", "", ["method", "code"]
 )
 
+# the type of the headers list, to be passed to the t.w.h.Headers.
+# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so
+# we simplify.
+RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]
+
+# the value really has to be a List, but List is invariant: List[Union[str, bytes]]
+# would reject a plain List[str], so we use the covariant Sequence instead.
+RawHeaderValue = Sequence[Union[str, bytes]]
+
+# the type of the query params, to be passed into `urlencode`
+QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
+QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
+
 
 def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
     """
@@ -285,13 +311,26 @@ class SimpleHttpClient:
                 ip_blacklist=self._ip_blacklist,
             )
 
-    async def request(self, method, uri, data=None, headers=None):
+    async def request(
+        self,
+        method: str,
+        uri: str,
+        data: Optional[bytes] = None,
+        headers: Optional[Headers] = None,
+    ) -> IResponse:
         """
         Args:
-            method (str): HTTP method to use.
-            uri (str): URI to query.
-            data (bytes): Data to send in the request body, if applicable.
-            headers (t.w.http_headers.Headers): Request headers.
+            method: HTTP method to use.
+            uri: URI to query.
+            data: Data to send in the request body, if applicable.
+            headers: Request headers.
+
+        Returns:
+            Response object, once the headers have been read.
+
+        Raises:
+            RequestTimedOutError: if the request times out before the headers are read
+
         """
         # A small wrapper around self.agent.request() so we can easily attach
         # counters to it
@@ -324,6 +363,8 @@ class SimpleHttpClient:
                     headers=headers,
                     **self._extra_treq_args
                 )
+                # we use our own timeout mechanism rather than treq's as a workaround
+                # for https://twistedmatrix.com/trac/ticket/9534.
                 request_deferred = timeout_deferred(
                     request_deferred,
                     60,
@@ -353,18 +394,26 @@ class SimpleHttpClient:
                 set_tag("error_reason", e.args[0])
                 raise
 
-    async def post_urlencoded_get_json(self, uri, args={}, headers=None):
+    async def post_urlencoded_get_json(
+        self,
+        uri: str,
+        args: Mapping[str, Union[str, List[str]]] = {},
+        headers: Optional[RawHeaders] = None,
+    ) -> Any:
         """
         Args:
-            uri (str):
-            args (dict[str, str|List[str]]): query params
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: uri to query
+            args: parameters to be url-encoded in the body
+            headers: a map from header name to a list of values for that header
 
         Returns:
-            object: parsed json
+            parsed json
 
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException: On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -398,19 +447,24 @@ class SimpleHttpClient:
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )
 
-    async def post_json_get_json(self, uri, post_json, headers=None):
+    async def post_json_get_json(
+        self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
+    ) -> Any:
         """
 
         Args:
-            uri (str):
-            post_json (object):
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: URI to query.
+            post_json: request body, to be encoded as json
+            headers: a map from header name to a list of values for that header
 
         Returns:
-            object: parsed json
+            parsed json
 
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException: On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -440,21 +494,22 @@ class SimpleHttpClient:
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )
 
-    async def get_json(self, uri, args={}, headers=None):
-        """ Gets some json from the given URI.
+    async def get_json(
+        self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+    ) -> Any:
+        """Gets some json from the given URI.
 
         Args:
-            uri (str): The URI to request, not including query parameters
-            args (dict): A dictionary used to create query strings, defaults to
-                None.
-                **Note**: The value of each key is assumed to be an iterable
-                and *not* a string.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: The URI to request, not including query parameters
+            args: A dictionary used to create the query string
+            headers: a map from header name to a list of values for that header
         Returns:
-            Succeeds when we get *any* 2xx HTTP response, with the
-            HTTP body as JSON.
+            Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -466,22 +521,27 @@ class SimpleHttpClient:
         body = await self.get_raw(uri, args, headers=headers)
         return json_decoder.decode(body.decode("utf-8"))
 
-    async def put_json(self, uri, json_body, args={}, headers=None):
-        """ Puts some json to the given URI.
+    async def put_json(
+        self,
+        uri: str,
+        json_body: Any,
+        args: QueryParams = {},
+        headers: Optional[RawHeaders] = None,
+    ) -> Any:
+        """Puts some json to the given URI.
 
         Args:
-            uri (str): The URI to request, not including query parameters
-            json_body (dict): The JSON to put in the HTTP body,
-            args (dict): A dictionary used to create query strings, defaults to
-                None.
-                **Note**: The value of each key is assumed to be an iterable
-                and *not* a string.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: The URI to request, not including query parameters
+            json_body: The JSON to put in the HTTP body.
+            args: A dictionary used to create query strings
+            headers: a map from header name to a list of values for that header
         Returns:
-            Succeeds when we get *any* 2xx HTTP response, with the
-            HTTP body as JSON.
+            Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
         Raises:
+             RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException On a non-2xx HTTP response.
 
             ValueError: if the response was not JSON
@@ -513,21 +573,23 @@ class SimpleHttpClient:
                 response.code, response.phrase.decode("ascii", errors="replace"), body
             )
 
-    async def get_raw(self, uri, args={}, headers=None):
-        """ Gets raw text from the given URI.
+    async def get_raw(
+        self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+    ) -> bytes:
+        """Gets raw text from the given URI.
 
         Args:
-            uri (str): The URI to request, not including query parameters
-            args (dict): A dictionary used to create query strings, defaults to
-                None.
-                **Note**: The value of each key is assumed to be an iterable
-                and *not* a string.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            uri: The URI to request, not including query parameters
+            args: A dictionary used to create query strings
+            headers: a map from header name to a list of values for that header
         Returns:
-            Succeeds when we get *any* 2xx HTTP response, with the
+            Succeeds when we get a 2xx HTTP response, with the
             HTTP body as bytes.
         Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
             HttpResponseException on a non-2xx HTTP response.
         """
         if len(args):
@@ -552,16 +614,29 @@ class SimpleHttpClient:
     # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
     # The two should be factored out.
 
-    async def get_file(self, url, output_stream, max_size=None, headers=None):
+    async def get_file(
+        self,
+        url: str,
+        output_stream: BinaryIO,
+        max_size: Optional[int] = None,
+        headers: Optional[RawHeaders] = None,
+    ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
         """GETs a file from a given URL
         Args:
-            url (str): The URL to GET
-            output_stream (file): File to write the response body to.
-            headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
-               header name to a list of values for that header
+            url: The URL to GET
+            output_stream: File to write the response body to.
+            headers: A map from header name to a list of values for that header
         Returns:
-            A (int,dict,string,int) tuple of the file length, dict of the response
+            A tuple of the file length, dict of the response
             headers, absolute URI of the response and HTTP response code.
+
+        Raises:
+            RequestTimedOutException: if there is a timeout before the response headers
+               are received. Note there is currently no timeout on reading the response
+               body.
+
+            SynapseError: if the response is not a 2xx, the remote file is too large, or
+               another exception happens during the download.
         """
 
         actual_headers = {b"User-Agent": [self.user_agent]}
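As a rough illustration of the new aliases (a sketch, not taken from the patch): `RawHeaders` keys may be str or bytes (but not mixed within one mapping) and each value is a sequence of str/bytes, while `QueryParams` values may be a single string or an iterable of strings. A hypothetical typed call might look like:

```python
from typing import Any

from synapse.http.client import QueryParams, RawHeaders, SimpleHttpClient


async def fetch_directory(client: SimpleHttpClient) -> Any:
    # Header names and values as bytes; each value is a list-like of str/bytes.
    headers = {b"Accept": [b"application/json"]}  # type: RawHeaders
    # Query params: a single value or an iterable of values per key.
    args = {"limit": "10", "fields": ["name", "topic"]}  # type: QueryParams
    # The URL is illustrative only.
    return await client.get_json(
        "https://example.com/_example/api", args=args, headers=headers
    )
```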
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 987765e877..dce6c4d168 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
             raise OEmbedError() from e
 
-    async def _download_url(self, url, user):
+    async def _download_url(self, url: str, user):
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
 
         # If this URL can be accessed via oEmbed, use that instead.
-        url_to_download = url
+        url_to_download = url  # type: Optional[str]
         oembed_url = self._get_oembed_url(url)
         if oembed_url:
             # The result might be a new URL to download, or it might be HTML content.
@@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource):
                 # FIXME: we should calculate a proper expiration based on the
                 # Cache-Control and Expire headers.  But for now, assume 1 hour.
                 expires = ONE_HOUR
-                etag = headers["ETag"][0] if "ETag" in headers else None
+                etag = (
+                    headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+                )
         else:
-            html_bytes = oembed_result.html.encode("utf-8")  # type: ignore
+            # we can only get here if we did an oembed request and have an oembed_result.html
+            assert oembed_result.html is not None
+            assert oembed_url is not None
+
+            html_bytes = oembed_result.html.encode("utf-8")
             with self.media_storage.store_into_file(file_info) as (f, fname, finish):
                 f.write(html_bytes)
                 await finish()
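The `ETag` change above reflects that `get_file` now returns the response headers as `Dict[bytes, List[bytes]]`, so both the header name and its value must be handled as bytes. A small sketch (illustrative names, not from the patch) of reading such a header map:

```python
from typing import Dict, List, Optional


def extract_etag(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
    # Header names and values arrive as bytes; decode the first ETag value if present.
    if b"ETag" in headers:
        return headers[b"ETag"][0].decode("ascii")
    return None
```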