summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2023-03-01 18:37:22 +0000
committerDavid Robertson <davidr@element.io>2023-03-02 16:21:36 +0000
commite41a90e89049776bf1d51c48f3b9a5612fc7810c (patch)
tree47bccda030373ce0a937ba713afce5d1685b06a9
parentWIP: listen for proxy requests (diff)
downloadsynapse-e41a90e89049776bf1d51c48f3b9a5612fc7810c.tar.xz
WIP: make proxy requests
-rw-r--r--synapse/http/federation/matrix_federation_agent.py227
-rw-r--r--synapse/http/matrixfederationclient.py37
2 files changed, 197 insertions, 67 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 0359231e7d..60d1fcb404 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
 import urllib.parse
-from typing import Any, Generator, List, Optional
+from typing import Any, Generator, List, Mapping, Optional
 from urllib.request import (  # type: ignore[attr-defined]
     getproxies_environment,
     proxy_bypass_environment,
@@ -30,10 +31,11 @@ from twisted.internet.interfaces import (
     IReactorCore,
     IStreamClientEndpoint,
 )
-from twisted.web.client import URI, Agent, HTTPConnectionPool
+from twisted.web.client import URI, Agent, HTTPConnectionPool, _AgentBase
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse
 
+from synapse.config.workers import InstanceLocationConfig
 from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http import proxyagent
 from synapse.http.client import BlacklistingAgentWrapper, BlacklistingReactorWrapper
@@ -49,10 +51,11 @@ logger = logging.getLogger(__name__)
 
 
 @implementer(IAgent)
-class MatrixFederationAgent:
-    """An Agent-like thing which provides a `request` method which correctly
-    handles resolving matrix server names when using matrix://. Handles standard
-    https URIs as normal.
+class BaseMatrixFederationAgent(abc.ABC):
+    """An Agent-like thing which provides a `request` method that accepts matrix://
+     URIs.
+
+    Handles standard https URIs as normal.
 
     Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
 
@@ -82,19 +85,16 @@ class MatrixFederationAgent:
             default implementation.
     """
 
+    _agent: IAgent
+
     def __init__(
         self,
         reactor: ISynapseReactor,
-        tls_client_options_factory: Optional[FederationPolicyForHTTPS],
         user_agent: bytes,
         ip_whitelist: IPSet,
         ip_blacklist: IPSet,
         _srv_resolver: Optional[SrvResolver] = None,
-        _well_known_resolver: Optional[WellKnownResolver] = None,
     ):
-        # proxy_reactor is not blacklisted
-        proxy_reactor = reactor
-
         # We need to use a DNS resolver which filters out blacklisted IP
         # addresses, to prevent DNS rebinding.
         reactor = BlacklistingReactorWrapper(reactor, ip_whitelist, ip_blacklist)
@@ -105,35 +105,8 @@ class MatrixFederationAgent:
         self._pool.maxPersistentPerHost = 5
         self._pool.cachedConnectionTimeout = 2 * 60
 
-        self._agent = Agent.usingEndpointFactory(
-            reactor,
-            MatrixHostnameEndpointFactory(
-                reactor,
-                proxy_reactor,
-                tls_client_options_factory,
-                _srv_resolver,
-            ),
-            pool=self._pool,
-        )
         self.user_agent = user_agent
-
-        if _well_known_resolver is None:
-            _well_known_resolver = WellKnownResolver(
-                reactor,
-                agent=BlacklistingAgentWrapper(
-                    ProxyAgent(
-                        reactor,
-                        proxy_reactor,
-                        pool=self._pool,
-                        contextFactory=tls_client_options_factory,
-                        use_proxy=True,
-                    ),
-                    ip_blacklist=ip_blacklist,
-                ),
-                user_agent=self.user_agent,
-            )
-
-        self._well_known_resolver = _well_known_resolver
+        self._reactor = reactor
 
     @defer.inlineCallbacks
     def request(
@@ -164,6 +137,86 @@ class MatrixFederationAgent:
         # explicit port.
         parsed_uri = urllib.parse.urlparse(uri)
 
+        parsed_uri = yield self.postprocess_uri(parsed_uri)
+
+        # We need to make sure the host header is set to the netloc of the
+        # server and that a user-agent is provided.
+        if headers is None:
+            request_headers = Headers()
+        else:
+            request_headers = headers.copy()
+
+        if not request_headers.hasHeader(b"host"):
+            request_headers.addRawHeader(b"host", parsed_uri.netloc)
+        if not request_headers.hasHeader(b"user-agent"):
+            request_headers.addRawHeader(b"user-agent", self.user_agent)
+
+        res = yield make_deferred_yieldable(
+            self._agent.request(method, uri, request_headers, bodyProducer)
+        )
+
+        return res
+
+    @abc.abstractmethod
+    @defer.inlineCallbacks
+    def postprocess_uri(
+        self, parsed_uri: "urllib.parse.ParseResultBytes"
+    ) -> Generator["defer.Deferred", Any, "urllib.parse.ParseResultBytes"]:
+        ...
+
+
+class MatrixFederationAgent(BaseMatrixFederationAgent):
+    """A federation agent that resolves server delegation by itself, using
+    SRV or .well-known lookups."""
+
+    def __init__(
+        self,
+        reactor: ISynapseReactor,
+        tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+        user_agent: bytes,
+        ip_whitelist: IPSet,
+        ip_blacklist: IPSet,
+        _srv_resolver: Optional[SrvResolver] = None,
+        _well_known_resolver: Optional[WellKnownResolver] = None,
+    ):
+        super().__init__(reactor, user_agent, ip_whitelist, ip_blacklist)
+
+        # proxy_reactor is not blacklisted
+        proxy_reactor = reactor
+
+        self._agent = Agent.usingEndpointFactory(
+            self._reactor,
+            MatrixHostnameEndpointFactory(
+                self._reactor,
+                proxy_reactor,
+                tls_client_options_factory,
+                _srv_resolver,
+            ),
+            pool=self._pool,
+        )
+
+        if _well_known_resolver is None:
+            _well_known_resolver = WellKnownResolver(
+                reactor,
+                agent=BlacklistingAgentWrapper(
+                    ProxyAgent(
+                        self._reactor,
+                        proxy_reactor,
+                        pool=self._pool,
+                        contextFactory=tls_client_options_factory,
+                        use_proxy=True,
+                    ),
+                    ip_blacklist=ip_blacklist,
+                ),
+                user_agent=self.user_agent,
+            )
+
+        self._well_known_resolver = _well_known_resolver
+
+    @defer.inlineCallbacks
+    def postprocess_uri(
+        self, parsed_uri: "urllib.parse.ParseResultBytes"
+    ) -> Generator[defer.Deferred, Any, "urllib.parse.ParseResultBytes"]:
         # There must be a valid hostname.
         assert parsed_uri.hostname
 
@@ -182,7 +235,6 @@ class MatrixFederationAgent:
                 self._well_known_resolver.get_well_known(parsed_uri.hostname)
             )
             delegated_server = well_known_result.delegated_server
-
         if delegated_server:
             # Ok, the server has delegated matrix traffic to somewhere else, so
             # lets rewrite the URL to replace the server with the delegated
@@ -198,24 +250,7 @@ class MatrixFederationAgent:
                 )
             )
             parsed_uri = urllib.parse.urlparse(uri)
-
-        # We need to make sure the host header is set to the netloc of the
-        # server and that a user-agent is provided.
-        if headers is None:
-            request_headers = Headers()
-        else:
-            request_headers = headers.copy()
-
-        if not request_headers.hasHeader(b"host"):
-            request_headers.addRawHeader(b"host", parsed_uri.netloc)
-        if not request_headers.hasHeader(b"user-agent"):
-            request_headers.addRawHeader(b"user-agent", self.user_agent)
-
-        res = yield make_deferred_yieldable(
-            self._agent.request(method, uri, request_headers, bodyProducer)
-        )
-
-        return res
+        return parsed_uri
 
 
 @implementer(IAgentEndpointFactory)
@@ -430,3 +465,79 @@ def _is_ip_literal(host: bytes) -> bool:
         return True
     except AddrFormatError:
         return False
+
+
+@implementer(IAgent)
+class InternalProxyMatrixFederationAgentInner(_AgentBase):
+    def __init__(
+        self,
+        reactor: IReactorCore,
+        endpointFactory: IAgentEndpointFactory,
+        pool: HTTPConnectionPool,
+    ):
+        _AgentBase.__init__(self, reactor, pool)
+        self._endpointFactory = endpointFactory
+
+    def request(
+        self,
+        method: bytes,
+        uri: bytes,
+        headers: Optional[Headers] = None,
+        bodyProducer: Optional[IBodyProducer] = None,
+    ) -> "defer.Deferred[IResponse]":
+        # Cache *all* connections under the same key, since we are only
+        # connecting to a single destination, the proxy:
+        # TODO make second entry an endpoint
+        key = ("http-proxy", None)
+        parsed_uri = URI.fromBytes(uri)
+        return self._requestWithEndpoint(
+            key,
+            self._endpointFactory.endpointForURI(parsed_uri),
+            method,
+            parsed_uri,
+            headers,
+            bodyProducer,
+            uri,
+        )
+
+
+class InternalProxyMatrixFederationAgent(BaseMatrixFederationAgent):
+    def __init__(
+        self,
+        reactor: ISynapseReactor,
+        user_agent: bytes,
+        ip_whitelist: IPSet,
+        ip_blacklist: IPSet,
+        proxy_workers: Mapping[str, InstanceLocationConfig],
+    ):
+        super().__init__(reactor, user_agent, ip_whitelist, ip_blacklist)
+        self._agent = InternalProxyMatrixFederationAgentInner(
+            self._reactor,
+            InternalProxyMatrixHostnameEndpointFactory(reactor, proxy_workers),
+            pool=self._pool,
+        )
+
+    @defer.inlineCallbacks
+    def postprocess_uri(
+        self, parsed_uri: "urllib.parse.ParseResultBytes"
+    ) -> Generator[defer.Deferred, Any, urllib.parse.ParseResultBytes]:
+        yield None
+        return parsed_uri
+
+
+@implementer(IAgentEndpointFactory)
+class InternalProxyMatrixHostnameEndpointFactory:
+    def __init__(
+        self, reactor: IReactorCore, proxy_workers: Mapping[str, InstanceLocationConfig]
+    ):
+        self._reactor = reactor
+        self._proxy_workers = proxy_workers
+
+    def endpointForURI(self, uri: URI) -> IStreamClientEndpoint:
+        # key = uri.toBytes()
+        # TODO chose instance based on key
+        proxy_location = next(iter(self._proxy_workers.values()))
+        # TODO does this need wrapping with wrapClientTLS?
+        logger.warning(f"DMR: make endpoint for {uri}, {self._proxy_workers=}")
+        rv = HostnameEndpoint(self._reactor, proxy_location.host, proxy_location.port)
+        return rv
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 3302d4e48a..f9d811fca1 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -49,7 +49,7 @@ from twisted.internet.interfaces import IReactorTime
 from twisted.internet.task import Cooperator
 from twisted.web.client import ResponseFailed
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IBodyProducer, IResponse
+from twisted.web.iweb import IAgent, IBodyProducer, IResponse
 
 import synapse.metrics
 import synapse.util.retryutils
@@ -70,7 +70,10 @@ from synapse.http.client import (
     encode_query_args,
     read_body_with_max_size,
 )
-from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.federation.matrix_federation_agent import (
+    InternalProxyMatrixFederationAgent,
+    MatrixFederationAgent,
+)
 from synapse.http.types import QueryParams
 from synapse.logging import opentracing
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -348,13 +351,29 @@ class MatrixFederationHttpClient:
         if hs.config.server.user_agent_suffix:
             user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix)
 
-        federation_agent = MatrixFederationAgent(
-            self.reactor,
-            tls_client_options_factory,
-            user_agent.encode("ascii"),
-            hs.config.server.federation_ip_range_whitelist,
-            hs.config.server.federation_ip_range_blacklist,
-        )
+        if (
+            hs.config.worker.outbound_fed_restricted_to
+            and hs.get_instance_name()
+            not in hs.config.worker.outbound_fed_restricted_to
+        ):
+            logger.warning("DMR: Using InternalProxyMatrixFederationAgent")
+            # We must
+            federation_agent: IAgent = InternalProxyMatrixFederationAgent(
+                self.reactor,
+                user_agent.encode("ascii"),
+                hs.config.server.federation_ip_range_whitelist,
+                hs.config.server.federation_ip_range_blacklist,
+                hs.config.worker.outbound_fed_restricted_to,
+            )
+        else:
+            logger.warning("DMR: Using MatrixFederationAgent")
+            federation_agent = MatrixFederationAgent(
+                self.reactor,
+                tls_client_options_factory,
+                user_agent.encode("ascii"),
+                hs.config.server.federation_ip_range_whitelist,
+                hs.config.server.federation_ip_range_blacklist,
+            )
 
         # Use a BlacklistingAgentWrapper to prevent circumventing the IP
         # blacklist via IP literals in server names