diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 7bdc4acae7..1fa3adbef2 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import random
import re
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple
from urllib.parse import urlparse
from urllib.request import ( # type: ignore[attr-defined]
getproxies_environment,
@@ -24,7 +25,12 @@ from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+ IProtocol,
+ IProtocolFactory,
+ IReactorCore,
+ IStreamClientEndpoint,
+)
from twisted.python.failure import Failure
from twisted.web.client import (
URI,
@@ -36,8 +42,10 @@ from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse
+from synapse.config.workers import InstanceLocationConfig
from synapse.http import redact_uri
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials
+from synapse.logging.context import run_in_background
logger = logging.getLogger(__name__)
@@ -74,6 +82,10 @@ class ProxyAgent(_AgentBase):
use_proxy: Whether proxy settings should be discovered and used
from conventional environment variables.
+ federation_proxies: An optional list of locations to proxy outbound federation
+ traffic through (only requests that use the `matrix-federation://` scheme
+ will be proxied).
+
Raises:
ValueError if use_proxy is set and the environment variables
contain an invalid proxy specification.
@@ -89,6 +101,7 @@ class ProxyAgent(_AgentBase):
bindAddress: Optional[bytes] = None,
pool: Optional[HTTPConnectionPool] = None,
use_proxy: bool = False,
+ federation_proxies: Collection[InstanceLocationConfig] = (),
):
contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
@@ -127,6 +140,27 @@ class ProxyAgent(_AgentBase):
self._policy_for_https = contextFactory
self._reactor = reactor
+ self._federation_proxy_endpoint: Optional[IStreamClientEndpoint] = None
+ if federation_proxies:
+ endpoints = []
+ for federation_proxy in federation_proxies:
+ endpoint = HostnameEndpoint(
+ self.proxy_reactor,
+ federation_proxy.host,
+ federation_proxy.port,
+ )
+
+ if federation_proxy.tls:
+ tls_connection_creator = self._policy_for_https.creatorForNetloc(
+ federation_proxy.host,
+ federation_proxy.port,
+ )
+ endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+
+ endpoints.append(endpoint)
+
+ self._federation_proxy_endpoint = _ProxyEndpoints(endpoints)
+
def request(
self,
method: bytes,
@@ -214,6 +248,14 @@ class ProxyAgent(_AgentBase):
parsed_uri.port,
self.https_proxy_creds,
)
+ elif (
+ parsed_uri.scheme == b"matrix-federation"
+ and self._federation_proxy_endpoint
+ ):
+ # Cache *all* connections under the same key, since we are only
+ # connecting to a single destination, the proxy:
+ endpoint = self._federation_proxy_endpoint
+ request_path = uri
else:
# not using a proxy
endpoint = HostnameEndpoint(
@@ -233,6 +275,11 @@ class ProxyAgent(_AgentBase):
endpoint = wrapClientTLS(tls_connection_creator, endpoint)
elif parsed_uri.scheme == b"http":
pass
+ elif (
+ parsed_uri.scheme == b"matrix-federation"
+ and self._federation_proxy_endpoint
+ ):
+ pass
else:
return defer.fail(
Failure(
@@ -337,3 +384,31 @@ def parse_proxy(
credentials = ProxyCredentials(b"".join([url.username, b":", url.password]))
return url.scheme, url.hostname, url.port or default_port, credentials
+
+
+@implementer(IStreamClientEndpoint)
+class _ProxyEndpoints:
+ """An endpoint that randomly iterates through a given list of endpoints at
+ each connection attempt.
+ """
+
+ def __init__(self, endpoints: Sequence[IStreamClientEndpoint]) -> None:
+ assert endpoints
+ self._endpoints = endpoints
+
+ def connect(
+ self, protocol_factory: IProtocolFactory
+ ) -> "defer.Deferred[IProtocol]":
+ """Implements IStreamClientEndpoint interface"""
+
+ return run_in_background(self._do_connect, protocol_factory)
+
+ async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
+ failures: List[Failure] = []
+ for endpoint in random.sample(self._endpoints, k=len(self._endpoints)):
+ try:
+ return await endpoint.connect(protocol_factory)
+ except Exception:
+ failures.append(Failure())
+
+ failures.pop().raiseException()
|