diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 7bdc4acae7..59ab8fad35 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import random
import re
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple
from urllib.parse import urlparse
from urllib.request import ( # type: ignore[attr-defined]
getproxies_environment,
@@ -23,8 +24,17 @@ from urllib.request import ( # type: ignore[attr-defined]
from zope.interface import implementer
from twisted.internet import defer
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.endpoints import (
+ HostnameEndpoint,
+ UNIXClientEndpoint,
+ wrapClientTLS,
+)
+from twisted.internet.interfaces import (
+ IProtocol,
+ IProtocolFactory,
+ IReactorCore,
+ IStreamClientEndpoint,
+)
from twisted.python.failure import Failure
from twisted.web.client import (
URI,
@@ -36,8 +46,18 @@ from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse
+from synapse.config.workers import (
+ InstanceLocationConfig,
+ InstanceTcpLocationConfig,
+ InstanceUnixLocationConfig,
+)
from synapse.http import redact_uri
-from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials
+from synapse.http.connectproxyclient import (
+ BasicProxyCredentials,
+ HTTPConnectProxyEndpoint,
+ ProxyCredentials,
+)
+from synapse.logging.context import run_in_background
logger = logging.getLogger(__name__)
@@ -74,6 +94,14 @@ class ProxyAgent(_AgentBase):
use_proxy: Whether proxy settings should be discovered and used
from conventional environment variables.
+ federation_proxy_locations: An optional list of locations to proxy outbound federation
+ traffic through (only requests that use the `matrix-federation://` scheme
+ will be proxied).
+
+ federation_proxy_credentials: Required if `federation_proxy_locations` is set. The
+ credentials to use when proxying outbound federation traffic through another
+ worker.
+
Raises:
ValueError if use_proxy is set and the environment variables
contain an invalid proxy specification.
@@ -89,6 +117,8 @@ class ProxyAgent(_AgentBase):
bindAddress: Optional[bytes] = None,
pool: Optional[HTTPConnectionPool] = None,
use_proxy: bool = False,
+ federation_proxy_locations: Collection[InstanceLocationConfig] = (),
+ federation_proxy_credentials: Optional[ProxyCredentials] = None,
):
contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
@@ -127,6 +157,47 @@ class ProxyAgent(_AgentBase):
self._policy_for_https = contextFactory
self._reactor = reactor
+ self._federation_proxy_endpoint: Optional[IStreamClientEndpoint] = None
+ self._federation_proxy_credentials: Optional[ProxyCredentials] = None
+ if federation_proxy_locations:
+ assert (
+ federation_proxy_credentials is not None
+ ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`"
+
+ endpoints: List[IStreamClientEndpoint] = []
+ for federation_proxy_location in federation_proxy_locations:
+ endpoint: IStreamClientEndpoint
+ if isinstance(federation_proxy_location, InstanceTcpLocationConfig):
+ endpoint = HostnameEndpoint(
+ self.proxy_reactor,
+ federation_proxy_location.host,
+ federation_proxy_location.port,
+ )
+ if federation_proxy_location.tls:
+ tls_connection_creator = (
+ self._policy_for_https.creatorForNetloc(
+ federation_proxy_location.host.encode("utf-8"),
+ federation_proxy_location.port,
+ )
+ )
+ endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+
+ elif isinstance(federation_proxy_location, InstanceUnixLocationConfig):
+ endpoint = UNIXClientEndpoint(
+ self.proxy_reactor, federation_proxy_location.path
+ )
+
+ else:
+ # It is supremely unlikely we ever hit this
+ raise SchemeNotSupported(
+ f"Unknown type of Endpoint requested, check {federation_proxy_location}"
+ )
+
+ endpoints.append(endpoint)
+
+ self._federation_proxy_endpoint = _RandomSampleEndpoints(endpoints)
+ self._federation_proxy_credentials = federation_proxy_credentials
+
def request(
self,
method: bytes,
@@ -214,6 +285,25 @@ class ProxyAgent(_AgentBase):
parsed_uri.port,
self.https_proxy_creds,
)
+ elif (
+ parsed_uri.scheme == b"matrix-federation"
+ and self._federation_proxy_endpoint
+ ):
+ assert (
+ self._federation_proxy_credentials is not None
+ ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`"
+
+ # Set a Proxy-Authorization header
+ if headers is None:
+ headers = Headers()
+ # We always need authentication for the outbound federation proxy
+ headers.addRawHeader(
+ b"Proxy-Authorization",
+ self._federation_proxy_credentials.as_proxy_authorization_value(),
+ )
+
+ endpoint = self._federation_proxy_endpoint
+ request_path = uri
else:
# not using a proxy
endpoint = HostnameEndpoint(
@@ -233,6 +323,11 @@ class ProxyAgent(_AgentBase):
endpoint = wrapClientTLS(tls_connection_creator, endpoint)
elif parsed_uri.scheme == b"http":
pass
+ elif (
+ parsed_uri.scheme == b"matrix-federation"
+ and self._federation_proxy_endpoint
+ ):
+ pass
else:
return defer.fail(
Failure(
@@ -334,6 +429,42 @@ def parse_proxy(
credentials = None
if url.username and url.password:
- credentials = ProxyCredentials(b"".join([url.username, b":", url.password]))
+ credentials = BasicProxyCredentials(
+ b"".join([url.username, b":", url.password])
+ )
return url.scheme, url.hostname, url.port or default_port, credentials
+
+
+@implementer(IStreamClientEndpoint)
+class _RandomSampleEndpoints:
+ """An endpoint that randomly iterates through a given list of endpoints at
+ each connection attempt.
+ """
+
+ def __init__(
+ self,
+ endpoints: Sequence[IStreamClientEndpoint],
+ ) -> None:
+ assert endpoints
+ self._endpoints = endpoints
+
+ def __repr__(self) -> str:
+ return f"<_RandomSampleEndpoints endpoints={self._endpoints}>"
+
+ def connect(
+ self, protocol_factory: IProtocolFactory
+ ) -> "defer.Deferred[IProtocol]":
+ """Implements IStreamClientEndpoint interface"""
+
+ return run_in_background(self._do_connect, protocol_factory)
+
+ async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
+ failures: List[Failure] = []
+ for endpoint in random.sample(self._endpoints, k=len(self._endpoints)):
+ try:
+ return await endpoint.connect(protocol_factory)
+ except Exception:
+ failures.append(Failure())
+
+ failures.pop().raiseException()
|