summary refs log tree commit diff
diff options
context:
space:
mode:
authorDirk Klimpel <5740567+dklimpel@users.noreply.github.com>2021-07-27 18:31:06 +0200
committerGitHub <noreply@github.com>2021-07-27 17:31:06 +0100
commit076deade028613da56391758305d645edeab40e5 (patch)
tree6f54f92b520aa2370646af34f6734893d2bbedfd
parentAdd a PeriodicallyFlushingMemoryHandler to prevent logging silence (#10407) (diff)
downloadsynapse-076deade028613da56391758305d645edeab40e5.tar.xz
allow specifying https:// proxy (#10411)
Diffstat (limited to '')
-rw-r--r--changelog.d/10411.feature1
-rw-r--r--synapse/http/proxyagent.py184
-rw-r--r--tests/http/test_proxyagent.py398
3 files changed, 450 insertions, 133 deletions
diff --git a/changelog.d/10411.feature b/changelog.d/10411.feature
new file mode 100644
index 0000000000..ef0ab84b17
--- /dev/null
+++ b/changelog.d/10411.feature
@@ -0,0 +1 @@
+Add support for https connections to a proxy server. Contributed by @Bubu and @dklimpel.
\ No newline at end of file
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index f7193e60bd..19e987f118 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -14,21 +14,32 @@
 import base64
 import logging
 import re
-from typing import Optional, Tuple
-from urllib.request import getproxies_environment, proxy_bypass_environment
+from typing import Any, Dict, Optional, Tuple
+from urllib.parse import urlparse
+from urllib.request import (  # type: ignore[attr-defined]
+    getproxies_environment,
+    proxy_bypass_environment,
+)
 
 import attr
 from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
 from twisted.python.failure import Failure
-from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
+from twisted.web.client import (
+    URI,
+    BrowserLikePolicyForHTTPS,
+    HTTPConnectionPool,
+    _AgentBase,
+)
 from twisted.web.error import SchemeNotSupported
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IPolicyForHTTPS
+from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS
 
 from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+from synapse.types import ISynapseReactor
 
 logger = logging.getLogger(__name__)
 
@@ -63,35 +74,38 @@ class ProxyAgent(_AgentBase):
                        reactor might have some blacklisting applied (i.e. for DNS queries),
                        but we need unblocked access to the proxy.
 
-        contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
+        contextFactory: A factory for TLS contexts, to control the
             verification parameters of OpenSSL.  The default is to use a
             `BrowserLikePolicyForHTTPS`, so unless you have special
             requirements you can leave this as-is.
 
-        connectTimeout (Optional[float]): The amount of time that this Agent will wait
+        connectTimeout: The amount of time that this Agent will wait
             for the peer to accept a connection, in seconds. If 'None',
             HostnameEndpoint's default (30s) will be used.
-
             This is used for connections to both proxies and destination servers.
 
-        bindAddress (bytes): The local address for client sockets to bind to.
+        bindAddress: The local address for client sockets to bind to.
 
-        pool (HTTPConnectionPool|None): connection pool to be used. If None, a
+        pool: connection pool to be used. If None, a
             non-persistent pool instance will be created.
 
-        use_proxy (bool): Whether proxy settings should be discovered and used
+        use_proxy: Whether proxy settings should be discovered and used
             from conventional environment variables.
+
+    Raises:
+        ValueError if use_proxy is set and the environment variables
+            contain an invalid proxy specification.
     """
 
     def __init__(
         self,
-        reactor,
-        proxy_reactor=None,
+        reactor: IReactorCore,
+        proxy_reactor: Optional[ISynapseReactor] = None,
         contextFactory: Optional[IPolicyForHTTPS] = None,
-        connectTimeout=None,
-        bindAddress=None,
-        pool=None,
-        use_proxy=False,
+        connectTimeout: Optional[float] = None,
+        bindAddress: Optional[bytes] = None,
+        pool: Optional[HTTPConnectionPool] = None,
+        use_proxy: bool = False,
     ):
         contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
 
@@ -102,7 +116,7 @@ class ProxyAgent(_AgentBase):
         else:
             self.proxy_reactor = proxy_reactor
 
-        self._endpoint_kwargs = {}
+        self._endpoint_kwargs: Dict[str, Any] = {}
         if connectTimeout is not None:
             self._endpoint_kwargs["timeout"] = connectTimeout
         if bindAddress is not None:
@@ -117,16 +131,12 @@ class ProxyAgent(_AgentBase):
             https_proxy = proxies["https"].encode() if "https" in proxies else None
             no_proxy = proxies["no"] if "no" in proxies else None
 
-        # Parse credentials from http and https proxy connection string if present
-        self.http_proxy_creds, http_proxy = parse_username_password(http_proxy)
-        self.https_proxy_creds, https_proxy = parse_username_password(https_proxy)
-
-        self.http_proxy_endpoint = _http_proxy_endpoint(
-            http_proxy, self.proxy_reactor, **self._endpoint_kwargs
+        self.http_proxy_endpoint, self.http_proxy_creds = _http_proxy_endpoint(
+            http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
         )
 
-        self.https_proxy_endpoint = _http_proxy_endpoint(
-            https_proxy, self.proxy_reactor, **self._endpoint_kwargs
+        self.https_proxy_endpoint, self.https_proxy_creds = _http_proxy_endpoint(
+            https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
         )
 
         self.no_proxy = no_proxy
@@ -134,7 +144,13 @@ class ProxyAgent(_AgentBase):
         self._policy_for_https = contextFactory
         self._reactor = reactor
 
-    def request(self, method, uri, headers=None, bodyProducer=None):
+    def request(
+        self,
+        method: bytes,
+        uri: bytes,
+        headers: Optional[Headers] = None,
+        bodyProducer: Optional[IBodyProducer] = None,
+    ) -> defer.Deferred:
         """
         Issue a request to the server indicated by the given uri.
 
@@ -146,16 +162,15 @@ class ProxyAgent(_AgentBase):
         See also: twisted.web.iweb.IAgent.request
 
         Args:
-            method (bytes): The request method to use, such as `GET`, `POST`, etc
+            method: The request method to use, such as `GET`, `POST`, etc
 
-            uri (bytes): The location of the resource to request.
+            uri: The location of the resource to request.
 
-            headers (Headers|None): Extra headers to send with the request
+            headers: Extra headers to send with the request
 
-            bodyProducer (IBodyProducer|None): An object which can generate bytes to
-                make up the body of this request (for example, the properly encoded
-                contents of a file for a file upload). Or, None if the request is to
-                have no body.
+            bodyProducer: An object which can generate bytes to make up the body of
+                this request (for example, the properly encoded contents of a file for
+                a file upload). Or, None if the request is to have no body.
 
         Returns:
             Deferred[IResponse]: completes when the header of the response has
@@ -253,70 +268,89 @@ class ProxyAgent(_AgentBase):
         )
 
 
-def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs):
+def _http_proxy_endpoint(
+    proxy: Optional[bytes],
+    reactor: IReactorCore,
+    tls_options_factory: IPolicyForHTTPS,
+    **kwargs,
+) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
     """Parses an http proxy setting and returns an endpoint for the proxy
 
     Args:
-        proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>]
-            Note that compared to other apps, this function currently lacks support
-            for specifying a protocol schema (i.e. protocol://...).
+        proxy: the proxy setting in the form: [scheme://][<username>:<password>@]<host>[:<port>]
+            This currently supports http:// and https:// proxies.
+            A hostname without scheme is assumed to be http.
 
         reactor: reactor to be used to connect to the proxy
 
+        tls_options_factory: the TLS options to use when connecting through a https proxy
+
         kwargs: other args to be passed to HostnameEndpoint
 
     Returns:
-        interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
-            or None
+        a tuple of
+            endpoint to use to connect to the proxy, or None
+            ProxyCredentials or if no credentials were found, or None
+
+    Raise:
+        ValueError if proxy has no hostname or unsupported scheme.
     """
     if proxy is None:
-        return None
+        return None, None
 
-    # Parse the connection string
-    host, port = parse_host_port(proxy, default_port=1080)
-    return HostnameEndpoint(reactor, host, port, **kwargs)
+    # Note: urlsplit/urlparse cannot be used here as that does not work (for Python
+    # 3.9+) on scheme-less proxies, e.g. host:port.
+    scheme, host, port, credentials = parse_proxy(proxy)
 
+    proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs)
 
-def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]:
-    """
-    Parses the username and password from a proxy declaration e.g
-    username:password@hostname:port.
+    if scheme == b"https":
+        tls_options = tls_options_factory.creatorForNetloc(host, port)
+        proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint)
 
-    Args:
-        proxy: The proxy connection string.
+    return proxy_endpoint, credentials
 
-    Returns
-        An instance of ProxyCredentials and the proxy connection string with any credentials
-        stripped, i.e u:p@host:port -> host:port. If no credentials were found, the
-        ProxyCredentials instance is replaced with None.
-    """
-    if proxy and b"@" in proxy:
-        # We use rsplit here as the password could contain an @ character
-        credentials, proxy_without_credentials = proxy.rsplit(b"@", 1)
-        return ProxyCredentials(credentials), proxy_without_credentials
 
-    return None, proxy
+def parse_proxy(
+    proxy: bytes, default_scheme: bytes = b"http", default_port: int = 1080
+) -> Tuple[bytes, bytes, int, Optional[ProxyCredentials]]:
+    """
+    Parse a proxy connection string.
 
+    Given a HTTP proxy URL, breaks it down into components and checks that it
+    has a hostname (otherwise it is not useful to us when trying to find a
+    proxy) and asserts that the URL has a scheme we support.
 
-def parse_host_port(hostport: bytes, default_port: int = None) -> Tuple[bytes, int]:
-    """
-    Parse the hostname and port from a proxy connection byte string.
 
     Args:
-        hostport: The proxy connection string. Must be in the form 'host[:port]'.
-        default_port: The default port to return if one is not found in `hostport`.
+        proxy: The proxy connection string. Must be in the form '[scheme://][<username>:<password>@]host[:port]'.
+        default_scheme: The default scheme to return if one is not found in `proxy`. Defaults to http
+        default_port: The default port to return if one is not found in `proxy`. Defaults to 1080
 
     Returns:
-        A tuple containing the hostname and port. Uses `default_port` if one was not found.
+        A tuple containing the scheme, hostname, port and ProxyCredentials.
+            If no credentials were found, the ProxyCredentials instance is replaced with None.
+
+    Raise:
+        ValueError if proxy has no hostname or unsupported scheme.
     """
-    if b":" in hostport:
-        host, port = hostport.rsplit(b":", 1)
-        try:
-            port = int(port)
-            return host, port
-        except ValueError:
-            # the thing after the : wasn't a valid port; presumably this is an
-            # IPv6 address.
-            pass
+    # First check if we have a scheme present
+    # Note: urlsplit/urlparse cannot be used (for Python # 3.9+) on scheme-less proxies, e.g. host:port.
+    if b"://" not in proxy:
+        proxy = b"".join([default_scheme, b"://", proxy])
+
+    url = urlparse(proxy)
+
+    if not url.hostname:
+        raise ValueError("Proxy URL did not contain a hostname! Please specify one.")
+
+    if url.scheme not in (b"http", b"https"):
+        raise ValueError(
+            f"Unknown proxy scheme {url.scheme!s}; only 'http' and 'https' is supported."
+        )
+
+    credentials = None
+    if url.username and url.password:
+        credentials = ProxyCredentials(b"".join([url.username, b":", url.password]))
 
-    return hostport, default_port
+    return url.scheme, url.hostname, url.port or default_port, credentials
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 437113929a..e5865c161d 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -14,19 +14,22 @@
 import base64
 import logging
 import os
-from typing import Optional
+from typing import Iterable, Optional
 from unittest.mock import patch
 
 import treq
 from netaddr import IPSet
+from parameterized import parameterized
 
 from twisted.internet import interfaces  # noqa: F401
+from twisted.internet.endpoints import HostnameEndpoint, _WrapperEndpoint
+from twisted.internet.interfaces import IProtocol, IProtocolFactory
 from twisted.internet.protocol import Factory
-from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
 from twisted.web.http import HTTPChannel
 
 from synapse.http.client import BlacklistingReactorWrapper
-from synapse.http.proxyagent import ProxyAgent
+from synapse.http.proxyagent import ProxyAgent, ProxyCredentials, parse_proxy
 
 from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
 from tests.server import FakeTransport, ThreadedMemoryReactorClock
@@ -37,33 +40,208 @@ logger = logging.getLogger(__name__)
 HTTPFactory = Factory.forProtocol(HTTPChannel)
 
 
+class ProxyParserTests(TestCase):
+    """
+    Values for test
+    [
+        proxy_string,
+        expected_scheme,
+        expected_hostname,
+        expected_port,
+        expected_credentials,
+    ]
+    """
+
+    @parameterized.expand(
+        [
+            # host
+            [b"localhost", b"http", b"localhost", 1080, None],
+            [b"localhost:9988", b"http", b"localhost", 9988, None],
+            # host+scheme
+            [b"https://localhost", b"https", b"localhost", 1080, None],
+            [b"https://localhost:1234", b"https", b"localhost", 1234, None],
+            # ipv4
+            [b"1.2.3.4", b"http", b"1.2.3.4", 1080, None],
+            [b"1.2.3.4:9988", b"http", b"1.2.3.4", 9988, None],
+            # ipv4+scheme
+            [b"https://1.2.3.4", b"https", b"1.2.3.4", 1080, None],
+            [b"https://1.2.3.4:9988", b"https", b"1.2.3.4", 9988, None],
+            # ipv6 - without brackets is broken
+            # [
+            #     b"2001:0db8:85a3:0000:0000:8a2e:0370:effe",
+            #     b"http",
+            #     b"2001:0db8:85a3:0000:0000:8a2e:0370:effe",
+            #     1080,
+            #     None,
+            # ],
+            # [
+            #     b"2001:0db8:85a3:0000:0000:8a2e:0370:1234",
+            #     b"http",
+            #     b"2001:0db8:85a3:0000:0000:8a2e:0370:1234",
+            #     1080,
+            #     None,
+            # ],
+            # [b"::1", b"http", b"::1", 1080, None],
+            # [b"::ffff:0.0.0.0", b"http", b"::ffff:0.0.0.0", 1080, None],
+            # ipv6 - with brackets
+            [
+                b"[2001:0db8:85a3:0000:0000:8a2e:0370:effe]",
+                b"http",
+                b"2001:0db8:85a3:0000:0000:8a2e:0370:effe",
+                1080,
+                None,
+            ],
+            [
+                b"[2001:0db8:85a3:0000:0000:8a2e:0370:1234]",
+                b"http",
+                b"2001:0db8:85a3:0000:0000:8a2e:0370:1234",
+                1080,
+                None,
+            ],
+            [b"[::1]", b"http", b"::1", 1080, None],
+            [b"[::ffff:0.0.0.0]", b"http", b"::ffff:0.0.0.0", 1080, None],
+            # ipv6+port
+            [
+                b"[2001:0db8:85a3:0000:0000:8a2e:0370:effe]:9988",
+                b"http",
+                b"2001:0db8:85a3:0000:0000:8a2e:0370:effe",
+                9988,
+                None,
+            ],
+            [
+                b"[2001:0db8:85a3:0000:0000:8a2e:0370:1234]:9988",
+                b"http",
+                b"2001:0db8:85a3:0000:0000:8a2e:0370:1234",
+                9988,
+                None,
+            ],
+            [b"[::1]:9988", b"http", b"::1", 9988, None],
+            [b"[::ffff:0.0.0.0]:9988", b"http", b"::ffff:0.0.0.0", 9988, None],
+            # ipv6+scheme
+            [
+                b"https://[2001:0db8:85a3:0000:0000:8a2e:0370:effe]",
+                b"https",
+                b"2001:0db8:85a3:0000:0000:8a2e:0370:effe",
+                1080,
+                None,
+            ],
+            [
+                b"https://[2001:0db8:85a3:0000:0000:8a2e:0370:1234]",
+                b"https",
+                b"2001:0db8:85a3:0000:0000:8a2e:0370:1234",
+                1080,
+                None,
+            ],
+            [b"https://[::1]", b"https", b"::1", 1080, None],
+            [b"https://[::ffff:0.0.0.0]", b"https", b"::ffff:0.0.0.0", 1080, None],
+            # ipv6+scheme+port
+            [
+                b"https://[2001:0db8:85a3:0000:0000:8a2e:0370:effe]:9988",
+                b"https",
+                b"2001:0db8:85a3:0000:0000:8a2e:0370:effe",
+                9988,
+                None,
+            ],
+            [
+                b"https://[2001:0db8:85a3:0000:0000:8a2e:0370:1234]:9988",
+                b"https",
+                b"2001:0db8:85a3:0000:0000:8a2e:0370:1234",
+                9988,
+                None,
+            ],
+            [b"https://[::1]:9988", b"https", b"::1", 9988, None],
+            # with credentials
+            [
+                b"https://user:pass@1.2.3.4:9988",
+                b"https",
+                b"1.2.3.4",
+                9988,
+                b"user:pass",
+            ],
+            [b"user:pass@1.2.3.4:9988", b"http", b"1.2.3.4", 9988, b"user:pass"],
+            [
+                b"https://user:pass@proxy.local:9988",
+                b"https",
+                b"proxy.local",
+                9988,
+                b"user:pass",
+            ],
+            [
+                b"user:pass@proxy.local:9988",
+                b"http",
+                b"proxy.local",
+                9988,
+                b"user:pass",
+            ],
+        ]
+    )
+    def test_parse_proxy(
+        self,
+        proxy_string: bytes,
+        expected_scheme: bytes,
+        expected_hostname: bytes,
+        expected_port: int,
+        expected_credentials: Optional[bytes],
+    ):
+        """
+        Tests that a given proxy URL will be broken into the components.
+        Args:
+            proxy_string: The proxy connection string.
+            expected_scheme: Expected value of proxy scheme.
+            expected_hostname: Expected value of proxy hostname.
+            expected_port: Expected value of proxy port.
+            expected_credentials: Expected value of credentials.
+                Must be in form '<username>:<password>' or None
+        """
+        proxy_cred = None
+        if expected_credentials:
+            proxy_cred = ProxyCredentials(expected_credentials)
+        self.assertEqual(
+            (
+                expected_scheme,
+                expected_hostname,
+                expected_port,
+                proxy_cred,
+            ),
+            parse_proxy(proxy_string),
+        )
+
+
 class MatrixFederationAgentTests(TestCase):
     def setUp(self):
         self.reactor = ThreadedMemoryReactorClock()
 
     def _make_connection(
-        self, client_factory, server_factory, ssl=False, expected_sni=None
-    ):
+        self,
+        client_factory: IProtocolFactory,
+        server_factory: IProtocolFactory,
+        ssl: bool = False,
+        expected_sni: Optional[bytes] = None,
+        tls_sanlist: Optional[Iterable[bytes]] = None,
+    ) -> IProtocol:
         """Builds a test server, and completes the outgoing client connection
 
         Args:
-            client_factory (interfaces.IProtocolFactory): the the factory that the
+            client_factory: the the factory that the
                 application is trying to use to make the outbound connection. We will
                 invoke it to build the client Protocol
 
-            server_factory (interfaces.IProtocolFactory): a factory to build the
+            server_factory: a factory to build the
                 server-side protocol
 
-            ssl (bool): If true, we will expect an ssl connection and wrap
+            ssl: If true, we will expect an ssl connection and wrap
                 server_factory with a TLSMemoryBIOFactory
 
-            expected_sni (bytes|None): the expected SNI value
+            expected_sni: the expected SNI value
+
+            tls_sanlist: list of SAN entries for the TLS cert presented by the server.
+                 Defaults to [b'DNS:test.com']
 
         Returns:
-            IProtocol: the server Protocol returned by server_factory
+            the server Protocol returned by server_factory
         """
         if ssl:
-            server_factory = _wrap_server_factory_for_tls(server_factory)
+            server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
 
         server_protocol = server_factory.buildProtocol(None)
 
@@ -98,22 +276,28 @@ class MatrixFederationAgentTests(TestCase):
             self.assertEqual(
                 server_name,
                 expected_sni,
-                "Expected SNI %s but got %s" % (expected_sni, server_name),
+                f"Expected SNI {expected_sni!s} but got {server_name!s}",
             )
 
         return http_protocol
 
-    def _test_request_direct_connection(self, agent, scheme, hostname, path):
+    def _test_request_direct_connection(
+        self,
+        agent: ProxyAgent,
+        scheme: bytes,
+        hostname: bytes,
+        path: bytes,
+    ):
         """Runs a test case for a direct connection not going through a proxy.
 
         Args:
-            agent (ProxyAgent): the proxy agent being tested
+            agent: the proxy agent being tested
 
-            scheme (bytes): expected to be either "http" or "https"
+            scheme: expected to be either "http" or "https"
 
-            hostname (bytes): the hostname to connect to in the test
+            hostname: the hostname to connect to in the test
 
-            path (bytes): the path to connect to in the test
+            path: the path to connect to in the test
         """
         is_https = scheme == b"https"
 
@@ -208,7 +392,7 @@ class MatrixFederationAgentTests(TestCase):
         """
         Tests that requests can be made through a proxy.
         """
-        self._do_http_request_via_proxy(auth_credentials=None)
+        self._do_http_request_via_proxy(ssl=False, auth_credentials=None)
 
     @patch.dict(
         os.environ,
@@ -218,12 +402,28 @@ class MatrixFederationAgentTests(TestCase):
         """
         Tests that authenticated requests can be made through a proxy.
         """
-        self._do_http_request_via_proxy(auth_credentials="bob:pinkponies")
+        self._do_http_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies")
+
+    @patch.dict(
+        os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"}
+    )
+    def test_http_request_via_https_proxy(self):
+        self._do_http_request_via_proxy(ssl=True, auth_credentials=None)
+
+    @patch.dict(
+        os.environ,
+        {
+            "http_proxy": "https://bob:pinkponies@proxy.com:8888",
+            "no_proxy": "unused.com",
+        },
+    )
+    def test_http_request_via_https_proxy_with_auth(self):
+        self._do_http_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies")
 
     @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
     def test_https_request_via_proxy(self):
         """Tests that TLS-encrypted requests can be made through a proxy"""
-        self._do_https_request_via_proxy(auth_credentials=None)
+        self._do_https_request_via_proxy(ssl=False, auth_credentials=None)
 
     @patch.dict(
         os.environ,
@@ -231,16 +431,40 @@ class MatrixFederationAgentTests(TestCase):
     )
     def test_https_request_via_proxy_with_auth(self):
         """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
-        self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
+        self._do_https_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies")
+
+    @patch.dict(
+        os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
+    )
+    def test_https_request_via_https_proxy(self):
+        """Tests that TLS-encrypted requests can be made through a proxy"""
+        self._do_https_request_via_proxy(ssl=True, auth_credentials=None)
+
+    @patch.dict(
+        os.environ,
+        {"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
+    )
+    def test_https_request_via_https_proxy_with_auth(self):
+        """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
+        self._do_https_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies")
 
     def _do_http_request_via_proxy(
         self,
-        auth_credentials: Optional[str] = None,
+        ssl: bool = False,
+        auth_credentials: Optional[bytes] = None,
     ):
+        """Send a http request via an agent and check that it is correctly received at
+            the proxy. The proxy can use either http or https.
+        Args:
+            ssl: True if we expect the request to connect via https to proxy
+            auth_credentials: credentials to authenticate at proxy
         """
-        Tests that requests can be made through a proxy.
-        """
-        agent = ProxyAgent(self.reactor, use_proxy=True)
+        if ssl:
+            agent = ProxyAgent(
+                self.reactor, use_proxy=True, contextFactory=get_test_https_policy()
+            )
+        else:
+            agent = ProxyAgent(self.reactor, use_proxy=True)
 
         self.reactor.lookups["proxy.com"] = "1.2.3.5"
         d = agent.request(b"GET", b"http://test.com")
@@ -254,7 +478,11 @@ class MatrixFederationAgentTests(TestCase):
 
         # make a test server, and wire up the client
         http_server = self._make_connection(
-            client_factory, _get_test_protocol_factory()
+            client_factory,
+            _get_test_protocol_factory(),
+            ssl=ssl,
+            tls_sanlist=[b"DNS:proxy.com"] if ssl else None,
+            expected_sni=b"proxy.com" if ssl else None,
         )
 
         # the FakeTransport is async, so we need to pump the reactor
@@ -272,7 +500,7 @@ class MatrixFederationAgentTests(TestCase):
 
         if auth_credentials is not None:
             # Compute the correct header value for Proxy-Authorization
-            encoded_credentials = base64.b64encode(b"bob:pinkponies")
+            encoded_credentials = base64.b64encode(auth_credentials)
             expected_header_value = b"Basic " + encoded_credentials
 
             # Validate the header's value
@@ -295,8 +523,15 @@ class MatrixFederationAgentTests(TestCase):
 
     def _do_https_request_via_proxy(
         self,
-        auth_credentials: Optional[str] = None,
+        ssl: bool = False,
+        auth_credentials: Optional[bytes] = None,
     ):
+        """Send a https request via an agent and check that it is correctly received at
+            the proxy and client. The proxy can use either http or https.
+        Args:
+            ssl: True if we expect the request to connect via https to proxy
+            auth_credentials: credentials to authenticate at proxy
+        """
         agent = ProxyAgent(
             self.reactor,
             contextFactory=get_test_https_policy(),
@@ -313,18 +548,15 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(host, "1.2.3.5")
         self.assertEqual(port, 1080)
 
-        # make a test HTTP server, and wire up the client
+        # make a test server to act as the proxy, and wire up the client
         proxy_server = self._make_connection(
-            client_factory, _get_test_protocol_factory()
+            client_factory,
+            _get_test_protocol_factory(),
+            ssl=ssl,
+            tls_sanlist=[b"DNS:proxy.com"] if ssl else None,
+            expected_sni=b"proxy.com" if ssl else None,
         )
-
-        # fish the transports back out so that we can do the old switcheroo
-        s2c_transport = proxy_server.transport
-        client_protocol = s2c_transport.other
-        c2s_transport = client_protocol.transport
-
-        # the FakeTransport is async, so we need to pump the reactor
-        self.reactor.advance(0)
+        assert isinstance(proxy_server, HTTPChannel)
 
         # now there should be a pending CONNECT request
         self.assertEqual(len(proxy_server.requests), 1)
@@ -340,7 +572,7 @@ class MatrixFederationAgentTests(TestCase):
 
         if auth_credentials is not None:
             # Compute the correct header value for Proxy-Authorization
-            encoded_credentials = base64.b64encode(b"bob:pinkponies")
+            encoded_credentials = base64.b64encode(auth_credentials)
             expected_header_value = b"Basic " + encoded_credentials
 
             # Validate the header's value
@@ -352,31 +584,49 @@ class MatrixFederationAgentTests(TestCase):
         # tell the proxy server not to close the connection
         proxy_server.persistent = True
 
-        # this just stops the http Request trying to do a chunked response
-        # request.setHeader(b"Content-Length", b"0")
         request.finish()
 
-        # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
-        ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
-        ssl_protocol = ssl_factory.buildProtocol(None)
-        http_server = ssl_protocol.wrappedProtocol
+        # now we make another test server to act as the upstream HTTP server.
+        server_ssl_protocol = _wrap_server_factory_for_tls(
+            _get_test_protocol_factory()
+        ).buildProtocol(None)
 
-        ssl_protocol.makeConnection(
-            FakeTransport(client_protocol, self.reactor, ssl_protocol)
-        )
-        c2s_transport.other = ssl_protocol
+        # Tell the HTTP server to send outgoing traffic back via the proxy's transport.
+        proxy_server_transport = proxy_server.transport
+        server_ssl_protocol.makeConnection(proxy_server_transport)
+
+        # ... and replace the protocol on the proxy's transport with the
+        # TLSMemoryBIOProtocol for the test server, so that incoming traffic
+        # to the proxy gets sent over to the HTTP(s) server.
+        #
+        # This needs a bit of gut-wrenching, which is different depending on whether
+        # the proxy is using TLS or not.
+        #
+        # (an alternative, possibly more elegant, approach would be to use a custom
+        # Protocol to implement the proxy, which starts out by forwarding to an
+        # HTTPChannel (to implement the CONNECT command) and can then be switched
+        # into a mode where it forwards its traffic to another Protocol.)
+        if ssl:
+            assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol)
+            proxy_server_transport.wrappedProtocol = server_ssl_protocol
+        else:
+            assert isinstance(proxy_server_transport, FakeTransport)
+            client_protocol = proxy_server_transport.other
+            c2s_transport = client_protocol.transport
+            c2s_transport.other = server_ssl_protocol
 
         self.reactor.advance(0)
 
-        server_name = ssl_protocol._tlsConnection.get_servername()
+        server_name = server_ssl_protocol._tlsConnection.get_servername()
         expected_sni = b"test.com"
         self.assertEqual(
             server_name,
             expected_sni,
-            "Expected SNI %s but got %s" % (expected_sni, server_name),
+            f"Expected SNI {expected_sni!s} but got {server_name!s}",
         )
 
         # now there should be a pending request
+        http_server = server_ssl_protocol.wrappedProtocol
         self.assertEqual(len(http_server.requests), 1)
 
         request = http_server.requests[0]
@@ -510,7 +760,7 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(
             server_name,
             expected_sni,
-            "Expected SNI %s but got %s" % (expected_sni, server_name),
+            f"Expected SNI {expected_sni!s} but got {server_name!s}",
         )
 
         # now there should be a pending request
@@ -529,16 +779,48 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")
 
+    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
+    def test_proxy_with_no_scheme(self):
+        http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
+        self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
+        self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
+        self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+
+    @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
+    def test_proxy_with_unsupported_scheme(self):
+        with self.assertRaises(ValueError):
+            ProxyAgent(self.reactor, use_proxy=True)
+
+    @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
+    def test_proxy_with_http_scheme(self):
+        http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
+        self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
+        self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
+        self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+
+    @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
+    def test_proxy_with_https_scheme(self):
+        https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
+        self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
+        self.assertEqual(
+            https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com"
+        )
+        self.assertEqual(
+            https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888
+        )
+
 
-def _wrap_server_factory_for_tls(factory, sanlist=None):
+def _wrap_server_factory_for_tls(
+    factory: IProtocolFactory, sanlist: Iterable[bytes] = None
+) -> IProtocolFactory:
     """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
 
     The resultant factory will create a TLS server which presents a certificate
     signed by our test CA, valid for the domains in `sanlist`
 
     Args:
-        factory (interfaces.IProtocolFactory): protocol factory to wrap
-        sanlist (iterable[bytes]): list of domains the cert should be valid for
+        factory: protocol factory to wrap
+        sanlist: list of domains the cert should be valid for
 
     Returns:
         interfaces.IProtocolFactory
@@ -552,7 +834,7 @@ def _wrap_server_factory_for_tls(factory, sanlist=None):
     )
 
 
-def _get_test_protocol_factory():
+def _get_test_protocol_factory() -> IProtocolFactory:
     """Get a protocol Factory which will build an HTTPChannel
 
     Returns:
@@ -566,6 +848,6 @@ def _get_test_protocol_factory():
     return server_factory
 
 
-def _log_request(request):
+def _log_request(request: str):
     """Implements Factory.log, which is expected by Request.finish"""
-    logger.info("Completed request %s", request)
+    logger.info(f"Completed request {request}")