summary refs log tree commit diff
path: root/synapse/http/proxyagent.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/proxyagent.py')
-rw-r--r--synapse/http/proxyagent.py173
1 files changed, 154 insertions, 19 deletions
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index b2a50c9105..59ab8fad35 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 import re
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple
 from urllib.parse import urlparse
 from urllib.request import (  # type: ignore[attr-defined]
     getproxies_environment,
@@ -23,8 +24,17 @@ from urllib.request import (  # type: ignore[attr-defined]
 from zope.interface import implementer
 
 from twisted.internet import defer
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.endpoints import (
+    HostnameEndpoint,
+    UNIXClientEndpoint,
+    wrapClientTLS,
+)
+from twisted.internet.interfaces import (
+    IProtocol,
+    IProtocolFactory,
+    IReactorCore,
+    IStreamClientEndpoint,
+)
 from twisted.python.failure import Failure
 from twisted.web.client import (
     URI,
@@ -34,10 +44,20 @@ from twisted.web.client import (
 )
 from twisted.web.error import SchemeNotSupported
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS
+from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse
 
-from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials
-from synapse.types import ISynapseReactor
+from synapse.config.workers import (
+    InstanceLocationConfig,
+    InstanceTcpLocationConfig,
+    InstanceUnixLocationConfig,
+)
+from synapse.http import redact_uri
+from synapse.http.connectproxyclient import (
+    BasicProxyCredentials,
+    HTTPConnectProxyEndpoint,
+    ProxyCredentials,
+)
+from synapse.logging.context import run_in_background
 
 logger = logging.getLogger(__name__)
 
@@ -53,7 +73,7 @@ class ProxyAgent(_AgentBase):
             connections.
 
         proxy_reactor: twisted reactor to use for connections to the proxy server
-                       reactor might have some blacklisting applied (i.e. for DNS queries),
+                       reactor might have some blocking applied (i.e. for DNS queries),
                        but we need unblocked access to the proxy.
 
         contextFactory: A factory for TLS contexts, to control the
@@ -74,6 +94,14 @@ class ProxyAgent(_AgentBase):
         use_proxy: Whether proxy settings should be discovered and used
             from conventional environment variables.
 
+        federation_proxy_locations: An optional list of locations to proxy outbound federation
+            traffic through (only requests that use the `matrix-federation://` scheme
+            will be proxied).
+
+        federation_proxy_credentials: Required if `federation_proxy_locations` is set. The
+            credentials to use when proxying outbound federation traffic through another
+            worker.
+
     Raises:
         ValueError if use_proxy is set and the environment variables
             contain an invalid proxy specification.
@@ -83,12 +111,14 @@ class ProxyAgent(_AgentBase):
     def __init__(
         self,
         reactor: IReactorCore,
-        proxy_reactor: Optional[ISynapseReactor] = None,
+        proxy_reactor: Optional[IReactorCore] = None,
         contextFactory: Optional[IPolicyForHTTPS] = None,
         connectTimeout: Optional[float] = None,
         bindAddress: Optional[bytes] = None,
         pool: Optional[HTTPConnectionPool] = None,
         use_proxy: bool = False,
+        federation_proxy_locations: Collection[InstanceLocationConfig] = (),
+        federation_proxy_credentials: Optional[ProxyCredentials] = None,
     ):
         contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
 
@@ -127,13 +157,54 @@ class ProxyAgent(_AgentBase):
         self._policy_for_https = contextFactory
         self._reactor = reactor
 
+        self._federation_proxy_endpoint: Optional[IStreamClientEndpoint] = None
+        self._federation_proxy_credentials: Optional[ProxyCredentials] = None
+        if federation_proxy_locations:
+            assert (
+                federation_proxy_credentials is not None
+            ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`"
+
+            endpoints: List[IStreamClientEndpoint] = []
+            for federation_proxy_location in federation_proxy_locations:
+                endpoint: IStreamClientEndpoint
+                if isinstance(federation_proxy_location, InstanceTcpLocationConfig):
+                    endpoint = HostnameEndpoint(
+                        self.proxy_reactor,
+                        federation_proxy_location.host,
+                        federation_proxy_location.port,
+                    )
+                    if federation_proxy_location.tls:
+                        tls_connection_creator = (
+                            self._policy_for_https.creatorForNetloc(
+                                federation_proxy_location.host.encode("utf-8"),
+                                federation_proxy_location.port,
+                            )
+                        )
+                        endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+
+                elif isinstance(federation_proxy_location, InstanceUnixLocationConfig):
+                    endpoint = UNIXClientEndpoint(
+                        self.proxy_reactor, federation_proxy_location.path
+                    )
+
+                else:
+                    # It is supremely unlikely we ever hit this
+                    raise SchemeNotSupported(
+                        f"Unknown type of Endpoint requested, check {federation_proxy_location}"
+                    )
+
+                endpoints.append(endpoint)
+
+            self._federation_proxy_endpoint = _RandomSampleEndpoints(endpoints)
+            self._federation_proxy_credentials = federation_proxy_credentials
+
     def request(
         self,
         method: bytes,
         uri: bytes,
         headers: Optional[Headers] = None,
         bodyProducer: Optional[IBodyProducer] = None,
-    ) -> defer.Deferred:
+    ) -> "defer.Deferred[IResponse]":
         """
         Issue a request to the server indicated by the given uri.
 
@@ -156,17 +227,17 @@ class ProxyAgent(_AgentBase):
                 a file upload). Or, None if the request is to have no body.
 
         Returns:
-            Deferred[IResponse]: completes when the header of the response has
-                 been received (regardless of the response status code).
+            A deferred which completes when the header of the response has
+            been received (regardless of the response status code).
 
-                 Can fail with:
-                    SchemeNotSupported: if the uri is not http or https
+            Can fail with:
+                SchemeNotSupported: if the uri is not http or https
 
-                    twisted.internet.error.TimeoutError if the server we are connecting
-                        to (proxy or destination) does not accept a connection before
-                        connectTimeout.
+                twisted.internet.error.TimeoutError if the server we are connecting
+                    to (proxy or destination) does not accept a connection before
+                    connectTimeout.
 
-                    ... other things too.
+                ... other things too.
         """
         uri = uri.strip()
         if not _VALID_URI.match(uri):
@@ -214,13 +285,36 @@ class ProxyAgent(_AgentBase):
                 parsed_uri.port,
                 self.https_proxy_creds,
             )
+        elif (
+            parsed_uri.scheme == b"matrix-federation"
+            and self._federation_proxy_endpoint
+        ):
+            assert (
+                self._federation_proxy_credentials is not None
+            ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`"
+
+            # Set a Proxy-Authorization header
+            if headers is None:
+                headers = Headers()
+            # We always need authentication for the outbound federation proxy
+            headers.addRawHeader(
+                b"Proxy-Authorization",
+                self._federation_proxy_credentials.as_proxy_authorization_value(),
+            )
+
+            endpoint = self._federation_proxy_endpoint
+            request_path = uri
         else:
             # not using a proxy
             endpoint = HostnameEndpoint(
                 self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
             )
 
-        logger.debug("Requesting %s via %s", uri, endpoint)
+        logger.debug(
+            "Requesting %s via %s",
+            redact_uri(uri.decode("ascii", errors="replace")),
+            endpoint,
+        )
 
         if parsed_uri.scheme == b"https":
             tls_connection_creator = self._policy_for_https.creatorForNetloc(
@@ -229,6 +323,11 @@ class ProxyAgent(_AgentBase):
             endpoint = wrapClientTLS(tls_connection_creator, endpoint)
         elif parsed_uri.scheme == b"http":
             pass
+        elif (
+            parsed_uri.scheme == b"matrix-federation"
+            and self._federation_proxy_endpoint
+        ):
+            pass
         else:
             return defer.fail(
                 Failure(
@@ -330,6 +429,42 @@ def parse_proxy(
 
     credentials = None
     if url.username and url.password:
-        credentials = ProxyCredentials(b"".join([url.username, b":", url.password]))
+        credentials = BasicProxyCredentials(
+            b"".join([url.username, b":", url.password])
+        )
 
     return url.scheme, url.hostname, url.port or default_port, credentials
+
+
+@implementer(IStreamClientEndpoint)
+class _RandomSampleEndpoints:
+    """An endpoint that randomly iterates through a given list of endpoints at
+    each connection attempt.
+    """
+
+    def __init__(
+        self,
+        endpoints: Sequence[IStreamClientEndpoint],
+    ) -> None:
+        assert endpoints
+        self._endpoints = endpoints
+
+    def __repr__(self) -> str:
+        return f"<_RandomSampleEndpoints endpoints={self._endpoints}>"
+
+    def connect(
+        self, protocol_factory: IProtocolFactory
+    ) -> "defer.Deferred[IProtocol]":
+        """Implements IStreamClientEndpoint interface"""
+
+        return run_in_background(self._do_connect, protocol_factory)
+
+    async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
+        failures: List[Failure] = []
+        for endpoint in random.sample(self._endpoints, k=len(self._endpoints)):
+            try:
+                return await endpoint.connect(protocol_factory)
+            except Exception:
+                failures.append(Failure())
+
+        failures.pop().raiseException()