diff --git a/changelog.d/8372.misc b/changelog.d/8372.misc
new file mode 100644
index 0000000000..a56e36de4b
--- /dev/null
+++ b/changelog.d/8372.misc
@@ -0,0 +1 @@
+Add type annotations to `SimpleHttpClient`.
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 1514c0f691..c526c28b93 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol),
)
try:
- info = await self.get_json(uri, {})
+ info = await self.get_json(uri)
if not _is_valid_3pe_metadata(info):
logger.warning(
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 13fcab3378..4694adc400 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -17,6 +17,18 @@
import logging
import urllib
from io import BytesIO
+from typing import (
+ Any,
+ BinaryIO,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
import treq
from canonicaljson import encode_canonical_json
@@ -37,6 +49,7 @@ from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import (
@@ -57,6 +70,19 @@ incoming_responses_counter = Counter(
"synapse_http_client_responses", "", ["method", "code"]
)
+# the type of the headers list, to be passed to the t.w.h.Headers.
+# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so
+# we simplify.
+RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]
+
+# the value actually has to be a List, but List is invariant so we can't specify that
+# the entries can either be strs or bytes.
+RawHeaderValue = Sequence[Union[str, bytes]]
+
+# the type of the query params, to be passed into `urlencode`
+QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
+QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
+
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
"""
@@ -285,13 +311,26 @@ class SimpleHttpClient:
ip_blacklist=self._ip_blacklist,
)
- async def request(self, method, uri, data=None, headers=None):
+ async def request(
+ self,
+ method: str,
+ uri: str,
+ data: Optional[bytes] = None,
+ headers: Optional[Headers] = None,
+ ) -> IResponse:
"""
Args:
- method (str): HTTP method to use.
- uri (str): URI to query.
- data (bytes): Data to send in the request body, if applicable.
- headers (t.w.http_headers.Headers): Request headers.
+ method: HTTP method to use.
+ uri: URI to query.
+ data: Data to send in the request body, if applicable.
+ headers: Request headers.
+
+ Returns:
+ Response object, once the headers have been read.
+
+ Raises:
+            RequestTimedOutError: if the request times out before the headers are read
+
"""
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
@@ -324,6 +363,8 @@ class SimpleHttpClient:
headers=headers,
**self._extra_treq_args
)
+ # we use our own timeout mechanism rather than treq's as a workaround
+ # for https://twistedmatrix.com/trac/ticket/9534.
request_deferred = timeout_deferred(
request_deferred,
60,
@@ -353,18 +394,26 @@ class SimpleHttpClient:
set_tag("error_reason", e.args[0])
raise
- async def post_urlencoded_get_json(self, uri, args={}, headers=None):
+ async def post_urlencoded_get_json(
+ self,
+ uri: str,
+ args: Mapping[str, Union[str, List[str]]] = {},
+ headers: Optional[RawHeaders] = None,
+ ) -> Any:
"""
Args:
- uri (str):
- args (dict[str, str|List[str]]): query params
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: uri to query
+ args: parameters to be url-encoded in the body
+ headers: a map from header name to a list of values for that header
Returns:
- object: parsed json
+ parsed json
Raises:
+            RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -398,19 +447,24 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def post_json_get_json(self, uri, post_json, headers=None):
+ async def post_json_get_json(
+ self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
+ ) -> Any:
"""
Args:
- uri (str):
- post_json (object):
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: URI to query.
+ post_json: request body, to be encoded as json
+ headers: a map from header name to a list of values for that header
Returns:
- object: parsed json
+ parsed json
Raises:
+            RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException: On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -440,21 +494,22 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def get_json(self, uri, args={}, headers=None):
- """ Gets some json from the given URI.
+ async def get_json(
+ self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+ ) -> Any:
+ """Gets some json from the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ args: A dictionary used to create query string
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
- HTTP body as JSON.
+ Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
+            RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -466,22 +521,27 @@ class SimpleHttpClient:
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
- async def put_json(self, uri, json_body, args={}, headers=None):
- """ Puts some json to the given URI.
+ async def put_json(
+ self,
+ uri: str,
+ json_body: Any,
+ args: QueryParams = {},
+        headers: Optional[RawHeaders] = None,
+ ) -> Any:
+ """Puts some json to the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- json_body (dict): The JSON to put in the HTTP body,
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ json_body: The JSON to put in the HTTP body,
+ args: A dictionary used to create query strings
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
- HTTP body as JSON.
+ Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.
Raises:
+            RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException On a non-2xx HTTP response.
ValueError: if the response was not JSON
@@ -513,21 +573,23 @@ class SimpleHttpClient:
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- async def get_raw(self, uri, args={}, headers=None):
- """ Gets raw text from the given URI.
+ async def get_raw(
+ self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+ ) -> bytes:
+ """Gets raw text from the given URI.
Args:
- uri (str): The URI to request, not including query parameters
- args (dict): A dictionary used to create query strings, defaults to
- None.
- **Note**: The value of each key is assumed to be an iterable
- and *not* a string.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ uri: The URI to request, not including query parameters
+ args: A dictionary used to create query strings
+ headers: a map from header name to a list of values for that header
Returns:
- Succeeds when we get *any* 2xx HTTP response, with the
+ Succeeds when we get a 2xx HTTP response, with the
HTTP body as bytes.
Raises:
+            RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
@@ -552,16 +614,29 @@ class SimpleHttpClient:
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
- async def get_file(self, url, output_stream, max_size=None, headers=None):
+ async def get_file(
+ self,
+ url: str,
+ output_stream: BinaryIO,
+ max_size: Optional[int] = None,
+ headers: Optional[RawHeaders] = None,
+ ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
"""GETs a file from a given URL
Args:
- url (str): The URL to GET
- output_stream (file): File to write the response body to.
- headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
- header name to a list of values for that header
+ url: The URL to GET
+ output_stream: File to write the response body to.
+ headers: A map from header name to a list of values for that header
Returns:
- A (int,dict,string,int) tuple of the file length, dict of the response
+ A tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
+
+ Raises:
+            RequestTimedOutError: if there is a timeout before the response headers
+ are received. Note there is currently no timeout on reading the response
+ body.
+
+ SynapseError: if the response is not a 2xx, the remote file is too large, or
+ another exception happens during the download.
"""
actual_headers = {b"User-Agent": [self.user_agent]}
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 987765e877..dce6c4d168 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
- async def _download_url(self, url, user):
+ async def _download_url(self, url: str, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
# If this URL can be accessed via oEmbed, use that instead.
- url_to_download = url
+ url_to_download = url # type: Optional[str]
oembed_url = self._get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
@@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource):
# FIXME: we should calculate a proper expiration based on the
# Cache-Control and Expire headers. But for now, assume 1 hour.
expires = ONE_HOUR
- etag = headers["ETag"][0] if "ETag" in headers else None
+ etag = (
+ headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+ )
else:
- html_bytes = oembed_result.html.encode("utf-8") # type: ignore
+ # we can only get here if we did an oembed request and have an oembed_result.html
+ assert oembed_result.html is not None
+ assert oembed_url is not None
+
+ html_bytes = oembed_result.html.encode("utf-8")
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
f.write(html_bytes)
await finish()
|