summary refs log tree commit diff
path: root/synapse/http/client.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/http/client.py211
1 files changed, 128 insertions, 83 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index f409368802..e5b13593f2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,9 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import urllib
+import urllib.parse
 from io import BytesIO
 from typing import (
+    TYPE_CHECKING,
     Any,
     BinaryIO,
     Dict,
@@ -31,7 +32,7 @@ from typing import (
 
 import treq
 from canonicaljson import encode_canonical_json
-from netaddr import IPAddress
+from netaddr import IPAddress, IPSet
 from prometheus_client import Counter
 from zope.interface import implementer, provider
 
@@ -39,6 +40,8 @@ from OpenSSL import SSL
 from OpenSSL.SSL import VERIFY_NONE
 from twisted.internet import defer, error as twisted_error, protocol, ssl
 from twisted.internet.interfaces import (
+    IAddress,
+    IHostResolution,
     IReactorPluggableNameResolver,
     IResolutionReceiver,
 )
@@ -53,7 +56,7 @@ from twisted.web.client import (
 )
 from twisted.web.http import PotentialDataLoss
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IAgent, IBodyProducer, IResponse
 
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags
 from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
@@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
 QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
 
 
-def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
+def check_against_blacklist(
+    ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
+) -> bool:
     """
+    Compares an IP address to allowed and disallowed IP sets.
+
     Args:
-        ip_address (netaddr.IPAddress)
-        ip_whitelist (netaddr.IPSet)
-        ip_blacklist (netaddr.IPSet)
+        ip_address: The IP address to check
+        ip_whitelist: Allowed IP addresses.
+        ip_blacklist: Disallowed IP addresses.
+
+    Returns:
+        True if the IP address is in the blacklist and not in the whitelist.
     """
     if ip_address in ip_blacklist:
         if ip_whitelist is None or ip_address not in ip_whitelist:
@@ -118,23 +131,30 @@ class IPBlacklistingResolver:
     addresses, preventing DNS rebinding attacks on URL preview.
     """
 
-    def __init__(self, reactor, ip_whitelist, ip_blacklist):
+    def __init__(
+        self,
+        reactor: IReactorPluggableNameResolver,
+        ip_whitelist: Optional[IPSet],
+        ip_blacklist: IPSet,
+    ):
         """
         Args:
-            reactor (twisted.internet.reactor)
-            ip_whitelist (netaddr.IPSet)
-            ip_blacklist (netaddr.IPSet)
+            reactor: The twisted reactor.
+            ip_whitelist: IP addresses to allow.
+            ip_blacklist: IP addresses to disallow.
         """
         self._reactor = reactor
         self._ip_whitelist = ip_whitelist
         self._ip_blacklist = ip_blacklist
 
-    def resolveHostName(self, recv, hostname, portNumber=0):
+    def resolveHostName(
+        self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
+    ) -> IResolutionReceiver:
 
         r = recv()
-        addresses = []
+        addresses = []  # type: List[IAddress]
 
-        def _callback():
+        def _callback() -> None:
             r.resolutionBegan(None)
 
             has_bad_ip = False
@@ -161,15 +181,15 @@ class IPBlacklistingResolver:
         @provider(IResolutionReceiver)
         class EndpointReceiver:
             @staticmethod
-            def resolutionBegan(resolutionInProgress):
+            def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
                 pass
 
             @staticmethod
-            def addressResolved(address):
+            def addressResolved(address: IAddress) -> None:
                 addresses.append(address)
 
             @staticmethod
-            def resolutionComplete():
+            def resolutionComplete() -> None:
                 _callback()
 
         self._reactor.nameResolver.resolveHostName(
@@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent):
     directly (without an IP address lookup).
     """
 
-    def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
+    def __init__(
+        self,
+        agent: IAgent,
+        ip_whitelist: Optional[IPSet] = None,
+        ip_blacklist: Optional[IPSet] = None,
+    ):
         """
         Args:
-            agent (twisted.web.client.Agent): The Agent to wrap.
-            reactor (twisted.internet.reactor)
-            ip_whitelist (netaddr.IPSet)
-            ip_blacklist (netaddr.IPSet)
+            agent: The Agent to wrap.
+            ip_whitelist: IP addresses to allow.
+            ip_blacklist: IP addresses to disallow.
         """
         self._agent = agent
         self._ip_whitelist = ip_whitelist
         self._ip_blacklist = ip_blacklist
 
-    def request(self, method, uri, headers=None, bodyProducer=None):
+    def request(
+        self,
+        method: bytes,
+        uri: bytes,
+        headers: Optional[Headers] = None,
+        bodyProducer: Optional[IBodyProducer] = None,
+    ) -> defer.Deferred:
         h = urllib.parse.urlparse(uri.decode("ascii"))
 
         try:
@@ -226,23 +256,23 @@ class SimpleHttpClient:
 
     def __init__(
         self,
-        hs,
-        treq_args={},
-        ip_whitelist=None,
-        ip_blacklist=None,
-        http_proxy=None,
-        https_proxy=None,
+        hs: "HomeServer",
+        treq_args: Dict[str, Any] = {},
+        ip_whitelist: Optional[IPSet] = None,
+        ip_blacklist: Optional[IPSet] = None,
+        http_proxy: Optional[bytes] = None,
+        https_proxy: Optional[bytes] = None,
     ):
         """
         Args:
-            hs (synapse.server.HomeServer)
-            treq_args (dict): Extra keyword arguments to be given to treq.request.
-            ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
+            hs
+            treq_args: Extra keyword arguments to be given to treq.request.
+            ip_blacklist: The IP addresses that are blacklisted that
                 we may not request.
-            ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
+            ip_whitelist: The whitelisted IP addresses, that we can
                request if it were otherwise caught in a blacklist.
-            http_proxy (bytes): proxy server to use for http connections. host[:port]
-            https_proxy (bytes): proxy server to use for https connections. host[:port]
+            http_proxy: proxy server to use for http connections. host[:port]
+            https_proxy: proxy server to use for https connections. host[:port]
         """
         self.hs = hs
 
@@ -306,7 +336,6 @@ class SimpleHttpClient:
             # by the DNS resolution.
             self.agent = BlacklistingAgentWrapper(
                 self.agent,
-                self.reactor,
                 ip_whitelist=self._ip_whitelist,
                 ip_blacklist=self._ip_blacklist,
             )
@@ -397,7 +426,7 @@ class SimpleHttpClient:
     async def post_urlencoded_get_json(
         self,
         uri: str,
-        args: Mapping[str, Union[str, List[str]]] = {},
+        args: Optional[Mapping[str, Union[str, List[str]]]] = None,
         headers: Optional[RawHeaders] = None,
     ) -> Any:
         """
@@ -422,9 +451,7 @@ class SimpleHttpClient:
         # TODO: Do we ever want to log message contents?
         logger.debug("post_urlencoded_get_json args: %s", args)
 
-        query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
-            "utf8"
-        )
+        query_bytes = encode_query_args(args)
 
         actual_headers = {
             b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -432,7 +459,7 @@ class SimpleHttpClient:
             b"Accept": [b"application/json"],
         }
         if headers:
-            actual_headers.update(headers)
+            actual_headers.update(headers)  # type: ignore
 
         response = await self.request(
             "POST", uri, headers=Headers(actual_headers), data=query_bytes
@@ -479,7 +506,7 @@ class SimpleHttpClient:
             b"Accept": [b"application/json"],
         }
         if headers:
-            actual_headers.update(headers)
+            actual_headers.update(headers)  # type: ignore
 
         response = await self.request(
             "POST", uri, headers=Headers(actual_headers), data=json_str
@@ -495,7 +522,10 @@ class SimpleHttpClient:
             )
 
     async def get_json(
-        self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+        self,
+        uri: str,
+        args: Optional[QueryParams] = None,
+        headers: Optional[RawHeaders] = None,
     ) -> Any:
         """Gets some json from the given URI.
 
@@ -516,7 +546,7 @@ class SimpleHttpClient:
         """
         actual_headers = {b"Accept": [b"application/json"]}
         if headers:
-            actual_headers.update(headers)
+            actual_headers.update(headers)  # type: ignore
 
         body = await self.get_raw(uri, args, headers=headers)
         return json_decoder.decode(body.decode("utf-8"))
@@ -525,7 +555,7 @@ class SimpleHttpClient:
         self,
         uri: str,
         json_body: Any,
-        args: QueryParams = {},
+        args: Optional[QueryParams] = None,
         headers: RawHeaders = None,
     ) -> Any:
         """Puts some json to the given URI.
@@ -546,9 +576,9 @@ class SimpleHttpClient:
 
             ValueError: if the response was not JSON
         """
-        if len(args):
-            query_bytes = urllib.parse.urlencode(args, True)
-            uri = "%s?%s" % (uri, query_bytes)
+        if args:
+            query_str = urllib.parse.urlencode(args, True)
+            uri = "%s?%s" % (uri, query_str)
 
         json_str = encode_canonical_json(json_body)
 
@@ -558,7 +588,7 @@ class SimpleHttpClient:
             b"Accept": [b"application/json"],
         }
         if headers:
-            actual_headers.update(headers)
+            actual_headers.update(headers)  # type: ignore
 
         response = await self.request(
             "PUT", uri, headers=Headers(actual_headers), data=json_str
@@ -574,7 +604,10 @@ class SimpleHttpClient:
             )
 
     async def get_raw(
-        self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+        self,
+        uri: str,
+        args: Optional[QueryParams] = None,
+        headers: Optional[RawHeaders] = None,
     ) -> bytes:
         """Gets raw text from the given URI.
 
@@ -592,13 +625,13 @@ class SimpleHttpClient:
 
             HttpResponseException on a non-2xx HTTP response.
         """
-        if len(args):
-            query_bytes = urllib.parse.urlencode(args, True)
-            uri = "%s?%s" % (uri, query_bytes)
+        if args:
+            query_str = urllib.parse.urlencode(args, True)
+            uri = "%s?%s" % (uri, query_str)
 
         actual_headers = {b"User-Agent": [self.user_agent]}
         if headers:
-            actual_headers.update(headers)
+            actual_headers.update(headers)  # type: ignore
 
         response = await self.request("GET", uri, headers=Headers(actual_headers))
 
@@ -641,7 +674,7 @@ class SimpleHttpClient:
 
         actual_headers = {b"User-Agent": [self.user_agent]}
         if headers:
-            actual_headers.update(headers)
+            actual_headers.update(headers)  # type: ignore
 
         response = await self.request("GET", url, headers=Headers(actual_headers))
 
@@ -649,12 +682,13 @@ class SimpleHttpClient:
 
         if (
             b"Content-Length" in resp_headers
+            and max_size
             and int(resp_headers[b"Content-Length"][0]) > max_size
         ):
-            logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
+            logger.warning("Requested URL is too large > %r bytes" % (max_size,))
             raise SynapseError(
                 502,
-                "Requested file is too large > %r bytes" % (self.max_size,),
+                "Requested file is too large > %r bytes" % (max_size,),
                 Codes.TOO_LARGE,
             )
 
@@ -668,7 +702,7 @@ class SimpleHttpClient:
 
         try:
             length = await make_deferred_yieldable(
-                _readBodyToFile(response, output_stream, max_size)
+                readBodyToFile(response, output_stream, max_size)
             )
         except SynapseError:
             # This can happen e.g. because the body is too large.
@@ -696,18 +730,16 @@ def _timeout_to_request_timed_out_error(f: Failure):
     return f
 
 
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
-
-
 class _ReadBodyToFileProtocol(protocol.Protocol):
-    def __init__(self, stream, deferred, max_size):
+    def __init__(
+        self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
+    ):
         self.stream = stream
         self.deferred = deferred
         self.length = 0
         self.max_size = max_size
 
-    def dataReceived(self, data):
+    def dataReceived(self, data: bytes) -> None:
         self.stream.write(data)
         self.length += len(data)
         if self.max_size is not None and self.length >= self.max_size:
@@ -721,7 +753,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
             self.deferred = defer.Deferred()
             self.transport.loseConnection()
 
-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Failure) -> None:
         if reason.check(ResponseDone):
             self.deferred.callback(self.length)
         elif reason.check(PotentialDataLoss):
@@ -732,35 +764,48 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
             self.deferred.errback(reason)
 
 
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
+def readBodyToFile(
+    response: IResponse, stream: BinaryIO, max_size: Optional[int]
+) -> defer.Deferred:
+    """
+    Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
 
+    Args:
+        response: The HTTP response to read from.
+        stream: The file-object to write to.
+        max_size: The maximum file size to allow.
+
+    Returns:
+        A Deferred which resolves to the length of the read body.
+    """
 
-def _readBodyToFile(response, stream, max_size):
     d = defer.Deferred()
     response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
     return d
 
 
-def encode_urlencode_args(args):
-    return {k: encode_urlencode_arg(v) for k, v in args.items()}
+def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
+    """
+    Encodes a map of query arguments to bytes which can be appended to a URL.
 
+    Args:
+        args: The query arguments, a mapping of string to string or list of strings.
+
+    Returns:
+        The query arguments encoded as bytes.
+    """
+    if args is None:
+        return b""
 
-def encode_urlencode_arg(arg):
-    if isinstance(arg, str):
-        return arg.encode("utf-8")
-    elif isinstance(arg, list):
-        return [encode_urlencode_arg(i) for i in arg]
-    else:
-        return arg
+    encoded_args = {}
+    for k, vs in args.items():
+        if isinstance(vs, str):
+            vs = [vs]
+        encoded_args[k] = [v.encode("utf8") for v in vs]
 
+    query_str = urllib.parse.urlencode(encoded_args, True)
 
-def _print_ex(e):
-    if hasattr(e, "reasons") and e.reasons:
-        for ex in e.reasons:
-            _print_ex(ex)
-    else:
-        logger.exception(e)
+    return query_str.encode("utf8")
 
 
 class InsecureInterceptableContextFactory(ssl.ContextFactory):