summary refs log tree commit diff
path: root/synapse/http/proxyagent.py
diff options
context:
space:
mode:
authorEric Eastwood <erice@element.io>2023-07-05 18:53:55 -0500
committerGitHub <noreply@github.com>2023-07-05 18:53:55 -0500
commitb07b14b494ae1dd564b4c44f844c9a9545b3d08a (patch)
tree3fadaf825910c72bb9fc1b4610610bbe6721eb1f /synapse/http/proxyagent.py
parentRemove support for Python 3.7 (#15851) (diff)
downloadsynapse-b07b14b494ae1dd564b4c44f844c9a9545b3d08a.tar.xz
Federation outbound proxy (#15773)
Allow configuring the set of workers to proxy outbound federation traffic through (`outbound_federation_restricted_to`).

This is useful when you have a worker setup with `federation_sender` instances responsible for sending outbound federation requests and want to make sure *all* outbound federation traffic goes through those instances. Before this change, the generic workers would still contact federation themselves for things like profile lookups, backfill, etc. This PR allows you to set more strict access controls/firewall for all workers and only allow the `federation_sender`'s to contact the outside world.

The original code is from @erikjohnston's branches which I've gotten in-shape to merge.
Diffstat (limited to 'synapse/http/proxyagent.py')
-rw-r--r--synapse/http/proxyagent.py79
1 files changed, 77 insertions, 2 deletions
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 7bdc4acae7..1fa3adbef2 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 import re
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple
 from urllib.parse import urlparse
 from urllib.request import (  # type: ignore[attr-defined]
     getproxies_environment,
@@ -24,7 +25,12 @@ from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+    IProtocol,
+    IProtocolFactory,
+    IReactorCore,
+    IStreamClientEndpoint,
+)
 from twisted.python.failure import Failure
 from twisted.web.client import (
     URI,
@@ -36,8 +42,10 @@ from twisted.web.error import SchemeNotSupported
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse
 
+from synapse.config.workers import InstanceLocationConfig
 from synapse.http import redact_uri
 from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials
+from synapse.logging.context import run_in_background
 
 logger = logging.getLogger(__name__)
 
@@ -74,6 +82,10 @@ class ProxyAgent(_AgentBase):
         use_proxy: Whether proxy settings should be discovered and used
             from conventional environment variables.
 
+        federation_proxies: An optional list of locations to proxy outbound federation
+            traffic through (only requests that use the `matrix-federation://` scheme
+            will be proxied).
+
     Raises:
         ValueError if use_proxy is set and the environment variables
             contain an invalid proxy specification.
@@ -89,6 +101,7 @@ class ProxyAgent(_AgentBase):
         bindAddress: Optional[bytes] = None,
         pool: Optional[HTTPConnectionPool] = None,
         use_proxy: bool = False,
+        federation_proxies: Collection[InstanceLocationConfig] = (),
     ):
         contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
 
@@ -127,6 +140,27 @@ class ProxyAgent(_AgentBase):
         self._policy_for_https = contextFactory
         self._reactor = reactor
 
+        self._federation_proxy_endpoint: Optional[IStreamClientEndpoint] = None
+        if federation_proxies:
+            endpoints = []
+            for federation_proxy in federation_proxies:
+                endpoint = HostnameEndpoint(
+                    self.proxy_reactor,
+                    federation_proxy.host,
+                    federation_proxy.port,
+                )
+
+                if federation_proxy.tls:
+                    tls_connection_creator = self._policy_for_https.creatorForNetloc(
+                        federation_proxy.host,
+                        federation_proxy.port,
+                    )
+                    endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+
+                endpoints.append(endpoint)
+
+            self._federation_proxy_endpoint = _ProxyEndpoints(endpoints)
+
     def request(
         self,
         method: bytes,
@@ -214,6 +248,14 @@ class ProxyAgent(_AgentBase):
                 parsed_uri.port,
                 self.https_proxy_creds,
             )
+        elif (
+            parsed_uri.scheme == b"matrix-federation"
+            and self._federation_proxy_endpoint
+        ):
+            # Cache *all* connections under the same key, since we are only
+            # connecting to a single destination, the proxy:
+            endpoint = self._federation_proxy_endpoint
+            request_path = uri
         else:
             # not using a proxy
             endpoint = HostnameEndpoint(
@@ -233,6 +275,11 @@ class ProxyAgent(_AgentBase):
             endpoint = wrapClientTLS(tls_connection_creator, endpoint)
         elif parsed_uri.scheme == b"http":
             pass
+        elif (
+            parsed_uri.scheme == b"matrix-federation"
+            and self._federation_proxy_endpoint
+        ):
+            pass
         else:
             return defer.fail(
                 Failure(
@@ -337,3 +384,31 @@ def parse_proxy(
         credentials = ProxyCredentials(b"".join([url.username, b":", url.password]))
 
     return url.scheme, url.hostname, url.port or default_port, credentials
+
+
+@implementer(IStreamClientEndpoint)
+class _ProxyEndpoints:
+    """An endpoint that randomly iterates through a given list of endpoints at
+    each connection attempt.
+    """
+
+    def __init__(self, endpoints: Sequence[IStreamClientEndpoint]) -> None:
+        assert endpoints
+        self._endpoints = endpoints
+
+    def connect(
+        self, protocol_factory: IProtocolFactory
+    ) -> "defer.Deferred[IProtocol]":
+        """Implements IStreamClientEndpoint interface"""
+
+        return run_in_background(self._do_connect, protocol_factory)
+
+    async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
+        failures: List[Failure] = []
+        for endpoint in random.sample(self._endpoints, k=len(self._endpoints)):
+            try:
+                return await endpoint.connect(protocol_factory)
+            except Exception:
+                failures.append(Failure())
+
+        failures.pop().raiseException()