summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/http/proxyagent.py184
1 files changed, 109 insertions, 75 deletions
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index f7193e60bd..19e987f118 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -14,21 +14,32 @@
 import base64
 import logging
 import re
-from typing import Optional, Tuple
-from urllib.request import getproxies_environment, proxy_bypass_environment
+from typing import Any, Dict, Optional, Tuple
+from urllib.parse import urlparse
+from urllib.request import (  # type: ignore[attr-defined]
+    getproxies_environment,
+    proxy_bypass_environment,
+)
 
 import attr
 from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
 from twisted.python.failure import Failure
-from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
+from twisted.web.client import (
+    URI,
+    BrowserLikePolicyForHTTPS,
+    HTTPConnectionPool,
+    _AgentBase,
+)
 from twisted.web.error import SchemeNotSupported
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IPolicyForHTTPS
+from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS
 
 from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+from synapse.types import ISynapseReactor
 
 logger = logging.getLogger(__name__)
 
@@ -63,35 +74,38 @@ class ProxyAgent(_AgentBase):
                        reactor might have some blacklisting applied (i.e. for DNS queries),
                        but we need unblocked access to the proxy.
 
-        contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
+        contextFactory: A factory for TLS contexts, to control the
             verification parameters of OpenSSL.  The default is to use a
             `BrowserLikePolicyForHTTPS`, so unless you have special
             requirements you can leave this as-is.
 
-        connectTimeout (Optional[float]): The amount of time that this Agent will wait
+        connectTimeout: The amount of time that this Agent will wait
             for the peer to accept a connection, in seconds. If 'None',
             HostnameEndpoint's default (30s) will be used.
-
             This is used for connections to both proxies and destination servers.
 
-        bindAddress (bytes): The local address for client sockets to bind to.
+        bindAddress: The local address for client sockets to bind to.
 
-        pool (HTTPConnectionPool|None): connection pool to be used. If None, a
+        pool: connection pool to be used. If None, a
             non-persistent pool instance will be created.
 
-        use_proxy (bool): Whether proxy settings should be discovered and used
+        use_proxy: Whether proxy settings should be discovered and used
             from conventional environment variables.
+
+    Raises:
+        ValueError if use_proxy is set and the environment variables
+            contain an invalid proxy specification.
     """
 
     def __init__(
         self,
-        reactor,
-        proxy_reactor=None,
+        reactor: IReactorCore,
+        proxy_reactor: Optional[ISynapseReactor] = None,
         contextFactory: Optional[IPolicyForHTTPS] = None,
-        connectTimeout=None,
-        bindAddress=None,
-        pool=None,
-        use_proxy=False,
+        connectTimeout: Optional[float] = None,
+        bindAddress: Optional[bytes] = None,
+        pool: Optional[HTTPConnectionPool] = None,
+        use_proxy: bool = False,
     ):
         contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
 
@@ -102,7 +116,7 @@ class ProxyAgent(_AgentBase):
         else:
             self.proxy_reactor = proxy_reactor
 
-        self._endpoint_kwargs = {}
+        self._endpoint_kwargs: Dict[str, Any] = {}
         if connectTimeout is not None:
             self._endpoint_kwargs["timeout"] = connectTimeout
         if bindAddress is not None:
@@ -117,16 +131,12 @@ class ProxyAgent(_AgentBase):
             https_proxy = proxies["https"].encode() if "https" in proxies else None
             no_proxy = proxies["no"] if "no" in proxies else None
 
-        # Parse credentials from http and https proxy connection string if present
-        self.http_proxy_creds, http_proxy = parse_username_password(http_proxy)
-        self.https_proxy_creds, https_proxy = parse_username_password(https_proxy)
-
-        self.http_proxy_endpoint = _http_proxy_endpoint(
-            http_proxy, self.proxy_reactor, **self._endpoint_kwargs
+        self.http_proxy_endpoint, self.http_proxy_creds = _http_proxy_endpoint(
+            http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
         )
 
-        self.https_proxy_endpoint = _http_proxy_endpoint(
-            https_proxy, self.proxy_reactor, **self._endpoint_kwargs
+        self.https_proxy_endpoint, self.https_proxy_creds = _http_proxy_endpoint(
+            https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
         )
 
         self.no_proxy = no_proxy
@@ -134,7 +144,13 @@ class ProxyAgent(_AgentBase):
         self._policy_for_https = contextFactory
         self._reactor = reactor
 
-    def request(self, method, uri, headers=None, bodyProducer=None):
+    def request(
+        self,
+        method: bytes,
+        uri: bytes,
+        headers: Optional[Headers] = None,
+        bodyProducer: Optional[IBodyProducer] = None,
+    ) -> defer.Deferred:
         """
         Issue a request to the server indicated by the given uri.
 
@@ -146,16 +162,15 @@ class ProxyAgent(_AgentBase):
         See also: twisted.web.iweb.IAgent.request
 
         Args:
-            method (bytes): The request method to use, such as `GET`, `POST`, etc
+            method: The request method to use, such as `GET`, `POST`, etc
 
-            uri (bytes): The location of the resource to request.
+            uri: The location of the resource to request.
 
-            headers (Headers|None): Extra headers to send with the request
+            headers: Extra headers to send with the request
 
-            bodyProducer (IBodyProducer|None): An object which can generate bytes to
-                make up the body of this request (for example, the properly encoded
-                contents of a file for a file upload). Or, None if the request is to
-                have no body.
+            bodyProducer: An object which can generate bytes to make up the body of
+                this request (for example, the properly encoded contents of a file for
+                a file upload). Or, None if the request is to have no body.
 
         Returns:
             Deferred[IResponse]: completes when the header of the response has
@@ -253,70 +268,89 @@ class ProxyAgent(_AgentBase):
         )
 
 
-def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs):
+def _http_proxy_endpoint(
+    proxy: Optional[bytes],
+    reactor: IReactorCore,
+    tls_options_factory: IPolicyForHTTPS,
+    **kwargs,
+) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
     """Parses an http proxy setting and returns an endpoint for the proxy
 
     Args:
-        proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>]
-            Note that compared to other apps, this function currently lacks support
-            for specifying a protocol schema (i.e. protocol://...).
+        proxy: the proxy setting in the form: [scheme://][<username>:<password>@]<host>[:<port>]
+            This currently supports http:// and https:// proxies.
+            A hostname without scheme is assumed to be http.
 
         reactor: reactor to be used to connect to the proxy
 
+        tls_options_factory: the TLS options to use when connecting through a https proxy
+
         kwargs: other args to be passed to HostnameEndpoint
 
     Returns:
-        interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
-            or None
+        a tuple of
+            endpoint to use to connect to the proxy, or None
+            ProxyCredentials or if no credentials were found, or None
+
+    Raise:
+        ValueError if proxy has no hostname or unsupported scheme.
     """
     if proxy is None:
-        return None
+        return None, None
 
-    # Parse the connection string
-    host, port = parse_host_port(proxy, default_port=1080)
-    return HostnameEndpoint(reactor, host, port, **kwargs)
+    # Note: urlsplit/urlparse cannot be used here as that does not work (for Python
+    # 3.9+) on scheme-less proxies, e.g. host:port.
+    scheme, host, port, credentials = parse_proxy(proxy)
 
+    proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs)
 
-def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]:
-    """
-    Parses the username and password from a proxy declaration e.g
-    username:password@hostname:port.
+    if scheme == b"https":
+        tls_options = tls_options_factory.creatorForNetloc(host, port)
+        proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint)
 
-    Args:
-        proxy: The proxy connection string.
+    return proxy_endpoint, credentials
 
-    Returns
-        An instance of ProxyCredentials and the proxy connection string with any credentials
-        stripped, i.e u:p@host:port -> host:port. If no credentials were found, the
-        ProxyCredentials instance is replaced with None.
-    """
-    if proxy and b"@" in proxy:
-        # We use rsplit here as the password could contain an @ character
-        credentials, proxy_without_credentials = proxy.rsplit(b"@", 1)
-        return ProxyCredentials(credentials), proxy_without_credentials
 
-    return None, proxy
+def parse_proxy(
+    proxy: bytes, default_scheme: bytes = b"http", default_port: int = 1080
+) -> Tuple[bytes, bytes, int, Optional[ProxyCredentials]]:
+    """
+    Parse a proxy connection string.
 
+    Given a HTTP proxy URL, breaks it down into components and checks that it
+    has a hostname (otherwise it is not useful to us when trying to find a
+    proxy) and asserts that the URL has a scheme we support.
 
-def parse_host_port(hostport: bytes, default_port: int = None) -> Tuple[bytes, int]:
-    """
-    Parse the hostname and port from a proxy connection byte string.
 
     Args:
-        hostport: The proxy connection string. Must be in the form 'host[:port]'.
-        default_port: The default port to return if one is not found in `hostport`.
+        proxy: The proxy connection string. Must be in the form '[scheme://][<username>:<password>@]host[:port]'.
+        default_scheme: The default scheme to return if one is not found in `proxy`. Defaults to http
+        default_port: The default port to return if one is not found in `proxy`. Defaults to 1080
 
     Returns:
-        A tuple containing the hostname and port. Uses `default_port` if one was not found.
+        A tuple containing the scheme, hostname, port and ProxyCredentials.
+            If no credentials were found, the ProxyCredentials instance is replaced with None.
+
+    Raise:
+        ValueError if proxy has no hostname or unsupported scheme.
     """
-    if b":" in hostport:
-        host, port = hostport.rsplit(b":", 1)
-        try:
-            port = int(port)
-            return host, port
-        except ValueError:
-            # the thing after the : wasn't a valid port; presumably this is an
-            # IPv6 address.
-            pass
+    # First check if we have a scheme present
+    # Note: urlsplit/urlparse cannot be used (for Python # 3.9+) on scheme-less proxies, e.g. host:port.
+    if b"://" not in proxy:
+        proxy = b"".join([default_scheme, b"://", proxy])
+
+    url = urlparse(proxy)
+
+    if not url.hostname:
+        raise ValueError("Proxy URL did not contain a hostname! Please specify one.")
+
+    if url.scheme not in (b"http", b"https"):
+        raise ValueError(
+            f"Unknown proxy scheme {url.scheme!s}; only 'http' and 'https' is supported."
+        )
+
+    credentials = None
+    if url.username and url.password:
+        credentials = ProxyCredentials(b"".join([url.username, b":", url.password]))
 
-    return hostport, default_port
+    return url.scheme, url.hostname, url.port or default_port, credentials