diff --git a/changelog.d/15773.feature b/changelog.d/15773.feature
new file mode 100644
index 0000000000..0d77fae2dc
--- /dev/null
+++ b/changelog.d/15773.feature
@@ -0,0 +1 @@
+Allow configuring the set of workers to proxy outbound federation traffic through via `outbound_federation_restricted_to`.
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 26d7c7900c..89a92c4682 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -3930,13 +3930,14 @@ federation_sender_instances:
---
### `instance_map`
-When using workers this should be a map from [`worker_name`](#worker_name) to the
-HTTP replication listener of the worker, if configured, and to the main process.
-Each worker declared under [`stream_writers`](../../workers.md#stream-writers) needs
-a HTTP replication listener, and that listener should be included in the `instance_map`.
-The main process also needs an entry on the `instance_map`, and it should be listed under
-`main` **if even one other worker exists**. Ensure the port matches with what is declared
-inside the `listener` block for a `replication` listener.
+When using workers this should be a map from [`worker_name`](#worker_name) to the HTTP
+replication listener of the worker, if configured, and to the main process. Each worker
+declared under [`stream_writers`](../../workers.md#stream-writers) and
+[`outbound_federation_restricted_to`](#outbound_federation_restricted_to) needs a HTTP replication listener, and that
+listener should be included in the `instance_map`. The main process also needs an entry
+on the `instance_map`, and it should be listed under `main` **if even one other worker
+exists**. Ensure the port matches with what is declared inside the `listener` block for
+a `replication` listener.
Example configuration:
@@ -3966,6 +3967,22 @@ stream_writers:
typing: worker1
```
---
+### `outbound_federation_restricted_to`
+
+When using workers, you can restrict outbound federation traffic to only go through a
+specific subset of workers. Any worker specified here must also be in the
+[`instance_map`](#instance_map).
+
+```yaml
+outbound_federation_restricted_to:
+ - federation_sender1
+ - federation_sender2
+```
+
+Also see the [worker
+documentation](../../workers.md#restrict-outbound-federation-traffic-to-a-specific-set-of-workers)
+for more info.
+---
### `run_background_tasks_on`
The [worker](../../workers.md#background-tasks) that is used to run
diff --git a/docs/workers.md b/docs/workers.md
index 735128762a..303e0f0e7a 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -531,6 +531,26 @@ the stream writer for the `presence` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/presence/
+#### Restrict outbound federation traffic to a specific set of workers
+
+The `outbound_federation_restricted_to` configuration is useful to make sure outbound
+federation traffic only goes through a specified subset of workers. This allows you to
+set stricter access controls (like a firewall) for all workers and only allow the
+`federation_sender` workers to contact the outside world.
+
+```yaml
+instance_map:
+ main:
+ host: localhost
+ port: 8030
+ federation_sender1:
+ host: localhost
+ port: 8034
+
+outbound_federation_restricted_to:
+ - federation_sender1
+```
+
#### Background tasks
There is also support for moving background tasks to a separate
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 936b1b0430..938ab40f27 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -386,6 +386,7 @@ def listen_unix(
def listen_http(
+ hs: "HomeServer",
listener_config: ListenerConfig,
root_resource: Resource,
version_string: str,
@@ -406,6 +407,7 @@ def listen_http(
version_string,
max_request_body_size=max_request_body_size,
reactor=reactor,
+ federation_agent=hs.get_federation_http_client().agent,
)
if isinstance(listener_config, TCPListenerConfig):
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 7406c3948c..dc79efcc14 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -221,6 +221,7 @@ class GenericWorkerServer(HomeServer):
root_resource = create_resource_tree(resources, OptionsResource())
_base.listen_http(
+ self,
listener_config,
root_resource,
self.version_string,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 84236ac299..f188c7265a 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -139,6 +139,7 @@ class SynapseHomeServer(HomeServer):
root_resource = OptionsResource()
ports = listen_http(
+ self,
listener_config,
create_resource_tree(resources, root_resource),
self.version_string,
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 38e13dd7b5..0b9789160c 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -15,7 +15,7 @@
import argparse
import logging
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
import attr
from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
@@ -148,6 +148,27 @@ class WriterLocations:
)
+@attr.s(auto_attribs=True)
+class OutboundFederationRestrictedTo:
+ """Whether we limit outbound federation to a certain set of instances.
+
+ Attributes:
+ instances: optional list of instances that can make outbound federation
+ requests. If None then all instances can make federation requests.
+ locations: list of instance locations to proxy outbound federation requests through.
+ """
+
+ instances: Optional[List[str]]
+ locations: List[InstanceLocationConfig] = attr.Factory(list)
+
+ def __contains__(self, instance: str) -> bool:
+ # It feels a bit dirty to return `True` if `instances` is `None`, but it matches
+ # the downstream usage: if `outbound_federation_restricted_to` is not configured,
+ # then any instance can talk to federation (no restrictions, so always return
+ # `True`).
+ return self.instances is None or instance in self.instances
+
+
class WorkerConfig(Config):
"""The workers are processes run separately to the main synapse process.
They have their own pid_file and listener configuration. They use the
@@ -357,6 +378,23 @@ class WorkerConfig(Config):
new_option_name="update_user_directory_from_worker",
)
+ outbound_federation_restricted_to = config.get(
+ "outbound_federation_restricted_to", None
+ )
+ self.outbound_federation_restricted_to = OutboundFederationRestrictedTo(
+ outbound_federation_restricted_to
+ )
+ if outbound_federation_restricted_to:
+ for instance in outbound_federation_restricted_to:
+ if instance not in self.instance_map:
+ raise ConfigError(
+ "Instance %r is configured in 'outbound_federation_restricted_to' but does not appear in `instance_map` config."
+ % (instance,)
+ )
+ self.outbound_federation_restricted_to.locations.append(
+ self.instance_map[instance]
+ )
+
def _should_this_worker_perform_duty(
self,
config: Dict[str, Any],
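The membership check above is what the federation client uses to decide whether an
instance may talk to federation directly. A minimal sketch of those semantics (the
instance names are hypothetical, and this assumes `synapse.config.workers` is
importable in your environment):

```python
from synapse.config.workers import OutboundFederationRestrictedTo

# When the option is not configured (`instances` is None), every instance may
# talk to federation directly.
unrestricted = OutboundFederationRestrictedTo(None)
assert "main" in unrestricted
assert "worker1" in unrestricted

# When restricted to a set of (hypothetical) federation senders, only those
# instances talk to federation directly; all others proxy through them.
restricted = OutboundFederationRestrictedTo(
    ["federation_sender1", "federation_sender2"]
)
assert "federation_sender1" in restricted
assert "main" not in restricted
```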
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 09ea93e10d..ca2cdbc6e2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -1037,7 +1037,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
- # stolen from https://github.com/twisted/treq/pull/49/files
+ # This applies to responses which don't set a `Content-Length` or
+ # `Transfer-Encoding` header, because then the end of the response is
+ # indicated by the connection being closed, an event which may also be due
+ # to a transient network problem or other error. But since this behavior is
+ # expected of some servers (like YouTube), let's ignore it.
+ # Stolen from https://github.com/twisted/treq/pull/49/files
# http://twistedmatrix.com/trac/ticket/4840
self.deferred.callback(self.length)
else:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index cc4e258b0f..b00396fdc7 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -50,7 +50,7 @@ from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IBodyProducer, IResponse
+from twisted.web.iweb import IAgent, IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
@@ -72,6 +72,7 @@ from synapse.http.client import (
read_body_with_max_size,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.proxyagent import ProxyAgent
from synapse.http.types import QueryParams
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -393,17 +394,32 @@ class MatrixFederationHttpClient:
if hs.config.server.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix)
- federation_agent = MatrixFederationAgent(
- self.reactor,
- tls_client_options_factory,
- user_agent.encode("ascii"),
- hs.config.server.federation_ip_range_allowlist,
- hs.config.server.federation_ip_range_blocklist,
+ outbound_federation_restricted_to = (
+ hs.config.worker.outbound_federation_restricted_to
)
+ if hs.get_instance_name() in outbound_federation_restricted_to:
+ # Talk to federation directly
+ federation_agent: IAgent = MatrixFederationAgent(
+ self.reactor,
+ tls_client_options_factory,
+ user_agent.encode("ascii"),
+ hs.config.server.federation_ip_range_allowlist,
+ hs.config.server.federation_ip_range_blocklist,
+ )
+ else:
+ # We need to talk to federation via one of the configured proxy
+ # locations
+ federation_proxies = outbound_federation_restricted_to.locations
+ federation_agent = ProxyAgent(
+ self.reactor,
+ self.reactor,
+ tls_client_options_factory,
+ federation_proxies=federation_proxies,
+ )
# Use a BlocklistingAgentWrapper to prevent circumventing the IP
# blocking via IP literals in server names
- self.agent = BlocklistingAgentWrapper(
+ self.agent: IAgent = BlocklistingAgentWrapper(
federation_agent,
ip_blocklist=hs.config.server.federation_ip_range_blocklist,
)
@@ -412,7 +428,6 @@ class MatrixFederationHttpClient:
self._store = hs.get_datastores().main
self.version_string_bytes = hs.version_string.encode("ascii")
self.default_timeout_seconds = hs.config.federation.client_timeout_ms / 1000
-
self.max_long_retry_delay_seconds = (
hs.config.federation.max_long_retry_delay_ms / 1000
)
@@ -1141,6 +1156,101 @@ class MatrixFederationHttpClient:
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
+ json_dict, _ = await self.get_json_with_headers(
+ destination=destination,
+ path=path,
+ args=args,
+ retry_on_dns_fail=retry_on_dns_fail,
+ timeout=timeout,
+ ignore_backoff=ignore_backoff,
+ try_trailing_slash_on_400=try_trailing_slash_on_400,
+ parser=parser,
+ )
+ return json_dict
+
+ @overload
+ async def get_json_with_headers(
+ self,
+ destination: str,
+ path: str,
+ args: Optional[QueryParams] = None,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ parser: Literal[None] = None,
+ ) -> Tuple[JsonDict, Dict[bytes, List[bytes]]]:
+ ...
+
+ @overload
+ async def get_json_with_headers(
+ self,
+ destination: str,
+ path: str,
+ args: Optional[QueryParams] = ...,
+ retry_on_dns_fail: bool = ...,
+ timeout: Optional[int] = ...,
+ ignore_backoff: bool = ...,
+ try_trailing_slash_on_400: bool = ...,
+ parser: ByteParser[T] = ...,
+ ) -> Tuple[T, Dict[bytes, List[bytes]]]:
+ ...
+
+ async def get_json_with_headers(
+ self,
+ destination: str,
+ path: str,
+ args: Optional[QueryParams] = None,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ parser: Optional[ByteParser[T]] = None,
+ ) -> Tuple[Union[JsonDict, T], Dict[bytes, List[bytes]]]:
+ """GETs some json from the given host homeserver and path
+
+ Args:
+ destination: The remote server to send the HTTP request to.
+
+ path: The HTTP path.
+
+ args: A dictionary used to create query strings, defaults to
+ None.
+
+ retry_on_dns_fail: true if the request should be retried on DNS failures
+
+ timeout: number of milliseconds to wait for the response.
+ Defaults to self.default_timeout_seconds (60s).
+
+ Note that we may make several attempts to send the request; this
+ timeout applies to the time spent waiting for response headers for
+ *each* attempt (including connection time) as well as the time spent
+ reading the response body after a 200 response.
+
+ ignore_backoff: true to ignore the historical backoff data
+ and try the request anyway.
+
+ try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
+ response we should try appending a trailing slash to the end of
+ the request. Workaround for #3622 in Synapse <= v0.99.3.
+
+ parser: The parser to use to decode the response. Defaults to
+ parsing as JSON.
+
+ Returns:
+ Succeeds when we get a 2xx HTTP response. The result will be a tuple of the
+ decoded JSON body and a dict of the response headers.
+
+ Raises:
+ HttpResponseException: If we get an HTTP response code >= 300
+ (except 429).
+ NotRetryingDestination: If we are not yet ready to retry this
+ server.
+ FederationDeniedError: If this destination is not on our
+ federation whitelist
+ RequestSendFailed: If there were problems connecting to the
+ remote, due to e.g. DNS failures, connection timeouts etc.
+ """
request = MatrixFederationRequest(
method="GET", destination=destination, path=path, query=args
)
@@ -1156,6 +1266,8 @@ class MatrixFederationHttpClient:
timeout=timeout,
)
+ headers = dict(response.headers.getAllRawHeaders())
+
if timeout is not None:
_sec_timeout = timeout / 1000
else:
@@ -1173,7 +1285,7 @@ class MatrixFederationHttpClient:
parser=parser,
)
- return body
+ return body, headers
async def delete_json(
self,
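The new `get_json_with_headers` helper returns both the decoded body and the raw
response headers, which the proxy tests below rely on. A minimal usage sketch (the
destination, path, and surrounding coroutine are hypothetical):

```python
from synapse.http.matrixfederationclient import MatrixFederationHttpClient

async def fetch_with_headers(client: MatrixFederationHttpClient) -> None:
    # `client` would typically come from `hs.get_federation_http_client()`.
    body, headers = await client.get_json_with_headers(
        destination="remote.example.com",
        path="/_matrix/federation/v1/version",
    )
    # `body` is the decoded JSON; `headers` maps raw header names to lists of values.
    content_type = headers.get(b"Content-Type", [])
    print(body, content_type)
```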
diff --git a/synapse/http/proxy.py b/synapse/http/proxy.py
new file mode 100644
index 0000000000..0874d67760
--- /dev/null
+++ b/synapse/http/proxy.py
@@ -0,0 +1,249 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import logging
+import urllib.parse
+from typing import TYPE_CHECKING, Any, Optional, Set, Tuple, cast
+
+from twisted.internet import protocol
+from twisted.internet.interfaces import ITCPTransport
+from twisted.internet.protocol import connectionDone
+from twisted.python import failure
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IAgent, IResponse
+from twisted.web.resource import IResource
+from twisted.web.server import Site
+
+from synapse.api.errors import Codes
+from synapse.http import QuieterFileBodyProducer
+from synapse.http.server import _AsyncResource
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import ISynapseReactor
+from synapse.util.async_helpers import timeout_deferred
+
+if TYPE_CHECKING:
+ from synapse.http.site import SynapseRequest
+
+logger = logging.getLogger(__name__)
+
+# "Hop-by-hop" headers (as opposed to "end-to-end" headers) as defined by RFC2616
+# section 13.5.1 and referenced in RFC9110 section 7.6.1. These are meant to only be
+# consumed by the immediate recipient and not be forwarded on.
+HOP_BY_HOP_HEADERS = {
+ "Connection",
+ "Keep-Alive",
+ "Proxy-Authenticate",
+ "Proxy-Authorization",
+ "TE",
+ "Trailers",
+ "Transfer-Encoding",
+ "Upgrade",
+}
+
+
+def parse_connection_header_value(
+ connection_header_value: Optional[bytes],
+) -> Set[str]:
+ """
+ Parse the `Connection` header to determine which headers should not be copied
+ over from the remote response.
+
+ As defined by RFC2616 section 14.10 and RFC9110 section 7.6.1
+
+ Example: `Connection: close, X-Foo, X-Bar` will return `{"Close", "X-Foo", "X-Bar"}`
+
+ Even though "close" is a special directive, let's just treat it as just another
+ header for simplicity. If people want to check for this directive, they can simply
+ check for `"Close" in headers`.
+
+ Args:
+ connection_header_value: The value of the `Connection` header.
+
+ Returns:
+ The set of header names that should not be copied over from the remote response.
+ The names are returned in their canonical capitalization.
+ """
+ headers = Headers()
+ extra_headers_to_remove: Set[str] = set()
+ if connection_header_value:
+ extra_headers_to_remove = {
+ headers._canonicalNameCaps(connection_option.strip()).decode("ascii")
+ for connection_option in connection_header_value.split(b",")
+ }
+
+ return extra_headers_to_remove
+
+
+class ProxyResource(_AsyncResource):
+ """
+ A stub resource that proxies any requests with a `matrix-federation://` scheme
+ through the given `federation_agent` to the remote homeserver and ferries back the
+ info.
+ """
+
+ isLeaf = True
+
+ def __init__(self, reactor: ISynapseReactor, federation_agent: IAgent):
+ super().__init__(True)
+
+ self.reactor = reactor
+ self.agent = federation_agent
+
+ async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]:
+ uri = urllib.parse.urlparse(request.uri)
+ assert uri.scheme == b"matrix-federation"
+
+ headers = Headers()
+ for header_name in (b"User-Agent", b"Authorization", b"Content-Type"):
+ header_value = request.getHeader(header_name)
+ if header_value:
+ headers.addRawHeader(header_name, header_value)
+
+ request_deferred = run_in_background(
+ self.agent.request,
+ request.method,
+ request.uri,
+ headers=headers,
+ bodyProducer=QuieterFileBodyProducer(request.content),
+ )
+ request_deferred = timeout_deferred(
+ request_deferred,
+ # This should be set longer than the timeout in `MatrixFederationHttpClient`
+ # so that it has enough time to complete and pass us the data before we give
+ # up.
+ timeout=90,
+ reactor=self.reactor,
+ )
+
+ response = await make_deferred_yieldable(request_deferred)
+
+ return response.code, response
+
+ def _send_response(
+ self,
+ request: "SynapseRequest",
+ code: int,
+ response_object: Any,
+ ) -> None:
+ response = cast(IResponse, response_object)
+ response_headers = cast(Headers, response.headers)
+
+ request.setResponseCode(code)
+
+ # The `Connection` header also defines which headers should not be copied over.
+ connection_header = response_headers.getRawHeaders(b"connection")
+ extra_headers_to_remove = parse_connection_header_value(
+ connection_header[0] if connection_header else None
+ )
+
+ # Copy headers.
+ for k, v in response_headers.getAllRawHeaders():
+ # Do not copy over any hop-by-hop headers. These are meant to only be
+ # consumed by the immediate recipient and not be forwarded on.
+ header_key = k.decode("ascii")
+ if (
+ header_key in HOP_BY_HOP_HEADERS
+ or header_key in extra_headers_to_remove
+ ):
+ continue
+
+ request.responseHeaders.setRawHeaders(k, v)
+
+ response.deliverBody(_ProxyResponseBody(request))
+
+ def _send_error_response(
+ self,
+ f: failure.Failure,
+ request: "SynapseRequest",
+ ) -> None:
+ request.setResponseCode(502)
+ request.setHeader(b"Content-Type", b"application/json")
+ request.write(
+ (
+ json.dumps(
+ {
+ "errcode": Codes.UNKNOWN,
+ "err": "ProxyResource: Error when proxying request: %s %s -> %s"
+ % (
+ request.method.decode("ascii"),
+ request.uri.decode("ascii"),
+ f,
+ ),
+ }
+ )
+ ).encode()
+ )
+ request.finish()
+
+
+class _ProxyResponseBody(protocol.Protocol):
+ """
+ A protocol that proxies the given remote response data back out to the given local
+ request.
+ """
+
+ transport: Optional[ITCPTransport] = None
+
+ def __init__(self, request: "SynapseRequest") -> None:
+ self._request = request
+
+ def dataReceived(self, data: bytes) -> None:
+ # Avoid sending response data to a local request that has already disconnected
+ if self._request._disconnected and self.transport is not None:
+ # Close the connection (forcefully) since all the data will get
+ # discarded anyway.
+ self.transport.abortConnection()
+ return
+
+ self._request.write(data)
+
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
+ # If the local request is already finished (successfully or failed), don't
+ # worry about sending anything back.
+ if self._request.finished:
+ return
+
+ if reason.check(ResponseDone):
+ self._request.finish()
+ else:
+ # Abort the underlying request since our remote request also failed.
+ self._request.transport.abortConnection()
+
+
+class ProxySite(Site):
+ """
+ Proxies any requests with a `matrix-federation://` scheme through the given
+ `federation_agent`. Otherwise, behaves like a normal `Site`.
+ """
+
+ def __init__(
+ self,
+ resource: IResource,
+ reactor: ISynapseReactor,
+ federation_agent: IAgent,
+ ):
+ super().__init__(resource, reactor=reactor)
+
+ self._proxy_resource = ProxyResource(reactor, federation_agent)
+
+ def getResourceFor(self, request: "SynapseRequest") -> IResource:
+ uri = urllib.parse.urlparse(request.uri)
+ if uri.scheme == b"matrix-federation":
+ return self._proxy_resource
+
+ return super().getResourceFor(request)
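Taken together, `HOP_BY_HOP_HEADERS` and `parse_connection_header_value` determine
which response headers `ProxyResource._send_response` forwards back to the local
request. A small sketch of that filtering, with a hypothetical `Connection` header
value:

```python
from synapse.http.proxy import HOP_BY_HOP_HEADERS, parse_connection_header_value

# Headers named in the remote `Connection` header must be dropped as well.
extra_to_remove = parse_connection_header_value(b"close, X-Foo, X-Bar")
assert extra_to_remove == {"Close", "X-Foo", "X-Bar"}

def should_forward(header_name: str) -> bool:
    # Mirrors the per-header check in ProxyResource._send_response (sketch only).
    return (
        header_name not in HOP_BY_HOP_HEADERS
        and header_name not in extra_to_remove
    )

assert should_forward("Content-Type")
assert not should_forward("Transfer-Encoding")  # hop-by-hop header
assert not should_forward("X-Foo")              # listed in `Connection`
```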
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 7bdc4acae7..1fa3adbef2 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import random
import re
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple
from urllib.parse import urlparse
from urllib.request import ( # type: ignore[attr-defined]
getproxies_environment,
@@ -24,7 +25,12 @@ from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+ IProtocol,
+ IProtocolFactory,
+ IReactorCore,
+ IStreamClientEndpoint,
+)
from twisted.python.failure import Failure
from twisted.web.client import (
URI,
@@ -36,8 +42,10 @@ from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse
+from synapse.config.workers import InstanceLocationConfig
from synapse.http import redact_uri
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials
+from synapse.logging.context import run_in_background
logger = logging.getLogger(__name__)
@@ -74,6 +82,10 @@ class ProxyAgent(_AgentBase):
use_proxy: Whether proxy settings should be discovered and used
from conventional environment variables.
+ federation_proxies: An optional list of locations to proxy outbound federation
+ traffic through (only requests that use the `matrix-federation://` scheme
+ will be proxied).
+
Raises:
ValueError if use_proxy is set and the environment variables
contain an invalid proxy specification.
@@ -89,6 +101,7 @@ class ProxyAgent(_AgentBase):
bindAddress: Optional[bytes] = None,
pool: Optional[HTTPConnectionPool] = None,
use_proxy: bool = False,
+ federation_proxies: Collection[InstanceLocationConfig] = (),
):
contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
@@ -127,6 +140,27 @@ class ProxyAgent(_AgentBase):
self._policy_for_https = contextFactory
self._reactor = reactor
+ self._federation_proxy_endpoint: Optional[IStreamClientEndpoint] = None
+ if federation_proxies:
+ endpoints = []
+ for federation_proxy in federation_proxies:
+ endpoint = HostnameEndpoint(
+ self.proxy_reactor,
+ federation_proxy.host,
+ federation_proxy.port,
+ )
+
+ if federation_proxy.tls:
+ tls_connection_creator = self._policy_for_https.creatorForNetloc(
+ federation_proxy.host,
+ federation_proxy.port,
+ )
+ endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+
+ endpoints.append(endpoint)
+
+ self._federation_proxy_endpoint = _ProxyEndpoints(endpoints)
+
def request(
self,
method: bytes,
@@ -214,6 +248,14 @@ class ProxyAgent(_AgentBase):
parsed_uri.port,
self.https_proxy_creds,
)
+ elif (
+ parsed_uri.scheme == b"matrix-federation"
+ and self._federation_proxy_endpoint
+ ):
+ # Cache *all* connections under the same key, since we are only
+ # connecting to a single destination, the proxy:
+ endpoint = self._federation_proxy_endpoint
+ request_path = uri
else:
# not using a proxy
endpoint = HostnameEndpoint(
@@ -233,6 +275,11 @@ class ProxyAgent(_AgentBase):
endpoint = wrapClientTLS(tls_connection_creator, endpoint)
elif parsed_uri.scheme == b"http":
pass
+ elif (
+ parsed_uri.scheme == b"matrix-federation"
+ and self._federation_proxy_endpoint
+ ):
+ pass
else:
return defer.fail(
Failure(
@@ -337,3 +384,31 @@ def parse_proxy(
credentials = ProxyCredentials(b"".join([url.username, b":", url.password]))
return url.scheme, url.hostname, url.port or default_port, credentials
+
+
+@implementer(IStreamClientEndpoint)
+class _ProxyEndpoints:
+ """An endpoint that randomly iterates through a given list of endpoints at
+ each connection attempt.
+ """
+
+ def __init__(self, endpoints: Sequence[IStreamClientEndpoint]) -> None:
+ assert endpoints
+ self._endpoints = endpoints
+
+ def connect(
+ self, protocol_factory: IProtocolFactory
+ ) -> "defer.Deferred[IProtocol]":
+ """Implements IStreamClientEndpoint interface"""
+
+ return run_in_background(self._do_connect, protocol_factory)
+
+ async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
+ failures: List[Failure] = []
+ for endpoint in random.sample(self._endpoints, k=len(self._endpoints)):
+ try:
+ return await endpoint.connect(protocol_factory)
+ except Exception:
+ failures.append(Failure())
+
+ failures.pop().raiseException()
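`_ProxyEndpoints` spreads connections across the configured proxy workers by
shuffling the endpoint list on every attempt and re-raising the last failure if none
of them connect. A standalone sketch of that fallback pattern, independent of
Twisted endpoints:

```python
import random
from typing import Awaitable, Callable, List, Sequence, TypeVar

T = TypeVar("T")

async def connect_to_any(connectors: Sequence[Callable[[], Awaitable[T]]]) -> T:
    # Try each connector in a random order; if all fail, re-raise the last error.
    failures: List[Exception] = []
    for connect in random.sample(list(connectors), k=len(connectors)):
        try:
            return await connect()
        except Exception as e:
            failures.append(e)
    raise failures.pop()
```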
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 933172c873..ff3153a9d9 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -18,6 +18,7 @@ import html
import logging
import types
import urllib
+import urllib.parse
from http import HTTPStatus
from http.client import FOUND
from inspect import isawaitable
@@ -65,7 +66,6 @@ from synapse.api.errors import (
UnrecognizedRequestError,
)
from synapse.config.homeserver import HomeServerConfig
-from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
from synapse.util import json_encoder
@@ -76,6 +76,7 @@ from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
import opentracing
+ from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -102,7 +103,7 @@ HTTP_STATUS_REQUEST_CANCELLED = 499
def return_json_error(
- f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig]
+ f: failure.Failure, request: "SynapseRequest", config: Optional[HomeServerConfig]
) -> None:
"""Sends a JSON error response to clients."""
@@ -220,8 +221,8 @@ def return_html_error(
def wrap_async_request_handler(
- h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]]
-) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]:
+ h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]]
+) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]:
"""Wraps an async request handler so that it calls request.processing.
This helps ensure that work done by the request handler after the request is completed
@@ -235,7 +236,7 @@ def wrap_async_request_handler(
"""
async def wrapped_async_request_handler(
- self: "_AsyncResource", request: SynapseRequest
+ self: "_AsyncResource", request: "SynapseRequest"
) -> None:
with request.processing():
await h(self, request)
@@ -300,7 +301,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
self._extract_context = extract_context
- def render(self, request: SynapseRequest) -> int:
+ def render(self, request: "SynapseRequest") -> int:
"""This gets called by twisted every time someone sends us a request."""
request.render_deferred = defer.ensureDeferred(
self._async_render_wrapper(request)
@@ -308,7 +309,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
return NOT_DONE_YET
@wrap_async_request_handler
- async def _async_render_wrapper(self, request: SynapseRequest) -> None:
+ async def _async_render_wrapper(self, request: "SynapseRequest") -> None:
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
@@ -326,9 +327,15 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
# of our stack, and thus gives us a sensible stack
# trace.
f = failure.Failure()
+ logger.exception(
+ "Error handling request",
+ exc_info=(f.type, f.value, f.getTracebackObject()),
+ )
self._send_error_response(f, request)
- async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]:
+ async def _async_render(
+ self, request: "SynapseRequest"
+ ) -> Optional[Tuple[int, Any]]:
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overridden in sub classes for
different routing.
@@ -358,7 +365,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
@abc.abstractmethod
def _send_response(
self,
- request: SynapseRequest,
+ request: "SynapseRequest",
code: int,
response_object: Any,
) -> None:
@@ -368,7 +375,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
def _send_error_response(
self,
f: failure.Failure,
- request: SynapseRequest,
+ request: "SynapseRequest",
) -> None:
raise NotImplementedError()
@@ -384,7 +391,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_response(
self,
- request: SynapseRequest,
+ request: "SynapseRequest",
code: int,
response_object: Any,
) -> None:
@@ -401,7 +408,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_error_response(
self,
f: failure.Failure,
- request: SynapseRequest,
+ request: "SynapseRequest",
) -> None:
"""Implements _AsyncResource._send_error_response"""
return_json_error(f, request, None)
@@ -473,7 +480,7 @@ class JsonResource(DirectServeJsonResource):
)
def _get_handler_for_request(
- self, request: SynapseRequest
+ self, request: "SynapseRequest"
) -> Tuple[ServletCallback, str, Dict[str, str]]:
"""Finds a callback method to handle the given request.
@@ -503,7 +510,7 @@ class JsonResource(DirectServeJsonResource):
# Huh. No one wanted to handle that? Fiiiiiine.
raise UnrecognizedRequestError(code=404)
- async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
+ async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
request.is_render_cancellable = is_function_cancellable(callback)
@@ -535,7 +542,7 @@ class JsonResource(DirectServeJsonResource):
def _send_error_response(
self,
f: failure.Failure,
- request: SynapseRequest,
+ request: "SynapseRequest",
) -> None:
"""Implements _AsyncResource._send_error_response"""
return_json_error(f, request, self.hs.config)
@@ -551,7 +558,7 @@ class DirectServeHtmlResource(_AsyncResource):
def _send_response(
self,
- request: SynapseRequest,
+ request: "SynapseRequest",
code: int,
response_object: Any,
) -> None:
@@ -565,7 +572,7 @@ class DirectServeHtmlResource(_AsyncResource):
def _send_error_response(
self,
f: failure.Failure,
- request: SynapseRequest,
+ request: "SynapseRequest",
) -> None:
"""Implements _AsyncResource._send_error_response"""
return_html_error(f, request, self.ERROR_TEMPLATE)
@@ -592,7 +599,7 @@ class UnrecognizedRequestResource(resource.Resource):
errcode of M_UNRECOGNIZED.
"""
- def render(self, request: SynapseRequest) -> int:
+ def render(self, request: "SynapseRequest") -> int:
f = failure.Failure(UnrecognizedRequestError(code=404))
return_json_error(f, request, None)
# A response has already been sent but Twisted requires either NOT_DONE_YET
@@ -622,7 +629,7 @@ class RootRedirect(resource.Resource):
class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""
- def render_OPTIONS(self, request: SynapseRequest) -> bytes:
+ def render_OPTIONS(self, request: "SynapseRequest") -> bytes:
request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0")
@@ -737,7 +744,7 @@ def _encode_json_bytes(json_object: object) -> bytes:
def respond_with_json(
- request: SynapseRequest,
+ request: "SynapseRequest",
code: int,
json_object: Any,
send_cors: bool = False,
@@ -787,7 +794,7 @@ def respond_with_json(
def respond_with_json_bytes(
- request: SynapseRequest,
+ request: "SynapseRequest",
code: int,
json_bytes: bytes,
send_cors: bool = False,
@@ -825,7 +832,7 @@ def respond_with_json_bytes(
async def _async_write_json_to_request_in_thread(
- request: SynapseRequest,
+ request: "SynapseRequest",
json_encoder: Callable[[Any], bytes],
json_object: Any,
) -> None:
@@ -883,7 +890,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
_ByteProducer(request, bytes_generator)
-def set_cors_headers(request: SynapseRequest) -> None:
+def set_cors_headers(request: "SynapseRequest") -> None:
"""Set the CORS headers so that javascript running in a web browsers can
use this API
@@ -981,7 +988,7 @@ def set_clickjacking_protection_headers(request: Request) -> None:
def respond_with_redirect(
- request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False
+ request: "SynapseRequest", url: bytes, statusCode: int = FOUND, cors: bool = False
) -> None:
"""
Write a 302 (or other specified status code) response to the request, if it is still alive.
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5b5a7c1e59..0ee2598345 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -21,25 +21,28 @@ from zope.interface import implementer
from twisted.internet.address import UNIXAddress
from twisted.internet.defer import Deferred
-from twisted.internet.interfaces import IAddress, IReactorTime
+from twisted.internet.interfaces import IAddress
from twisted.python.failure import Failure
from twisted.web.http import HTTPChannel
+from twisted.web.iweb import IAgent
from twisted.web.resource import IResource, Resource
-from twisted.web.server import Request, Site
+from twisted.web.server import Request
from synapse.config.server import ListenerConfig
from synapse.http import get_request_user_agent, redact_uri
+from synapse.http.proxy import ProxySite
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import (
ContextRequest,
LoggingContext,
PreserveLoggingContext,
)
-from synapse.types import Requester
+from synapse.types import ISynapseReactor, Requester
if TYPE_CHECKING:
import opentracing
+
logger = logging.getLogger(__name__)
_next_request_seq = 0
@@ -102,7 +105,7 @@ class SynapseRequest(Request):
# A boolean indicating whether `render_deferred` should be cancelled if the
# client disconnects early. Expected to be set by the coroutine started by
# `Resource.render`, if rendering is asynchronous.
- self.is_render_cancellable = False
+ self.is_render_cancellable: bool = False
global _next_request_seq
self.request_seq = _next_request_seq
@@ -601,7 +604,7 @@ class _XForwardedForAddress:
host: str
-class SynapseSite(Site):
+class SynapseSite(ProxySite):
"""
Synapse-specific twisted http Site
@@ -623,7 +626,8 @@ class SynapseSite(Site):
resource: IResource,
server_version_string: str,
max_request_body_size: int,
- reactor: IReactorTime,
+ reactor: ISynapseReactor,
+ federation_agent: IAgent,
):
"""
@@ -638,7 +642,11 @@ class SynapseSite(Site):
dropping the connection
reactor: reactor to be used to manage connection timeouts
"""
- Site.__init__(self, resource, reactor=reactor)
+ super().__init__(
+ resource=resource,
+ reactor=reactor,
+ federation_agent=federation_agent,
+ )
self.site_tag = site_tag
self.reactor = reactor
@@ -649,7 +657,9 @@ class SynapseSite(Site):
request_id_header = config.http_options.request_id_header
- self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886
+ self.experimental_cors_msc3886: bool = (
+ config.http_options.experimental_cors_msc3886
+ )
def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class(
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 5a965f233b..21c5309740 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -31,9 +31,7 @@ from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver(
- federation_http_client=None, homeserver_to_use=GenericWorkerServer
- )
+ hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
return hs
def default_config(self) -> JsonDict:
@@ -91,9 +89,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
@patch("synapse.app.homeserver.KeyResource", new=Mock())
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver(
- federation_http_client=None, homeserver_to_use=SynapseHomeServer
- )
+ hs = self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
return hs
@parameterized.expand(
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index ee48f9e546..66215af2b8 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -41,7 +41,6 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.appservice_api = mock.Mock()
hs = self.setup_test_homeserver(
"server",
- federation_http_client=None,
application_service_api=self.appservice_api,
)
handler = hs.get_device_handler()
@@ -401,7 +400,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver("server", federation_http_client=None)
+ hs = self.setup_test_homeserver("server")
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index bf0862ed54..5f11d5df11 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -57,7 +57,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver(federation_http_client=None)
+ hs = self.setup_test_homeserver()
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
return hs
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 19f5322317..fd66d573d2 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -993,7 +993,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(
"server",
- federation_http_client=None,
federation_sender=Mock(spec=FederationSender),
)
return hs
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 94518a7196..5da1d95f0b 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -17,6 +17,8 @@ import json
from typing import Dict, List, Set
from unittest.mock import ANY, Mock, call
+from netaddr import IPSet
+
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
@@ -24,6 +26,7 @@ from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.handlers.typing import TypingWriterHandler
+from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
@@ -76,6 +79,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# we mock out the federation client too
self.mock_federation_client = Mock(spec=["put_json"])
self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
+ self.mock_federation_client.agent = MatrixFederationAgent(
+ reactor,
+ tls_client_options_factory=None,
+ user_agent=b"SynapseInTrialTest/0.0.0",
+ ip_allowlist=None,
+ ip_blocklist=IPSet(),
+ )
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index b5f4a60fe5..a8b9737d1f 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Generator
-from unittest.mock import Mock
+from typing import Any, Dict, Generator
+from unittest.mock import ANY, Mock, create_autospec
from netaddr import IPSet
from parameterized import parameterized
@@ -21,10 +21,11 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred, TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
from twisted.test.proto_helpers import MemoryReactor, StringTransport
-from twisted.web.client import ResponseNeverReceived
+from twisted.web.client import Agent, ResponseNeverReceived
from twisted.web.http import HTTPChannel
+from twisted.web.http_headers import Headers
-from synapse.api.errors import RequestSendFailed
+from synapse.api.errors import HttpResponseException, RequestSendFailed
from synapse.http.matrixfederationclient import (
ByteParser,
MatrixFederationHttpClient,
@@ -39,7 +40,9 @@ from synapse.logging.context import (
from synapse.server import HomeServer
from synapse.util import Clock
+from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeTransport
+from tests.test_utils import FakeResponse
from tests.unittest import HomeserverTestCase, override_config
@@ -658,3 +661,181 @@ class FederationClientTests(HomeserverTestCase):
self.assertEqual(self.cl.max_short_retry_delay_seconds, 7)
self.assertEqual(self.cl.max_long_retries, 20)
self.assertEqual(self.cl.max_short_retries, 5)
+
+
+class FederationClientProxyTests(BaseMultiWorkerStreamTestCase):
+ def default_config(self) -> Dict[str, Any]:
+ conf = super().default_config()
+ conf["instance_map"] = {
+ "main": {"host": "testserv", "port": 8765},
+ "federation_sender": {"host": "testserv", "port": 1001},
+ }
+ return conf
+
+ @override_config({"outbound_federation_restricted_to": ["federation_sender"]})
+ def test_proxy_requests_through_federation_sender_worker(self) -> None:
+ """
+ Test that all outbound federation requests go through the `federation_sender`
+ worker
+ """
+ # Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance
+ # so we can act like some remote server responding to requests
+ mock_client_on_federation_sender = Mock()
+ mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True)
+ mock_client_on_federation_sender.agent = mock_agent_on_federation_sender
+
+ # Create the `federation_sender` worker
+ self.federation_sender = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ {"worker_name": "federation_sender"},
+ federation_http_client=mock_client_on_federation_sender,
+ )
+
+ # Fake `remoteserv:8008` responding to requests
+ mock_agent_on_federation_sender.request.side_effect = (
+ lambda *args, **kwargs: defer.succeed(
+ FakeResponse.json(
+ payload={
+ "foo": "bar",
+ }
+ )
+ )
+ )
+
+ # This federation request from the main process should be proxied through the
+ # `federation_sender` worker off to the remote server
+ test_request_from_main_process_d = defer.ensureDeferred(
+ self.hs.get_federation_http_client().get_json("remoteserv:8008", "foo/bar")
+ )
+
+ # Pump the reactor so our deferred goes through the motions
+ self.pump()
+
+ # Make sure that the request was proxied through the `federation_sender` worker
+ mock_agent_on_federation_sender.request.assert_called_once_with(
+ b"GET",
+ b"matrix-federation://remoteserv:8008/foo/bar",
+ headers=ANY,
+ bodyProducer=ANY,
+ )
+
+ # Make sure the response is as expected back on the main worker
+ res = self.successResultOf(test_request_from_main_process_d)
+ self.assertEqual(res, {"foo": "bar"})
+
+ @override_config({"outbound_federation_restricted_to": ["federation_sender"]})
+ def test_proxy_request_with_network_error_through_federation_sender_worker(
+ self,
+ ) -> None:
+ """
+ Test that when the outbound federation request fails with a network related
+ error, a sensible error makes its way back to the main process.
+ """
+ # Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance
+ # so we can act like some remote server responding to requests
+ mock_client_on_federation_sender = Mock()
+ mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True)
+ mock_client_on_federation_sender.agent = mock_agent_on_federation_sender
+
+ # Create the `federation_sender` worker
+ self.federation_sender = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ {"worker_name": "federation_sender"},
+ federation_http_client=mock_client_on_federation_sender,
+ )
+
+ # Fake `remoteserv:8008` responding to requests
+ mock_agent_on_federation_sender.request.side_effect = (
+ lambda *args, **kwargs: defer.fail(ResponseNeverReceived("fake error"))
+ )
+
+ # This federation request from the main process should be proxied through the
+ # `federation_sender` worker off to the remote server
+ test_request_from_main_process_d = defer.ensureDeferred(
+ self.hs.get_federation_http_client().get_json("remoteserv:8008", "foo/bar")
+ )
+
+ # Pump the reactor so our deferred goes through the motions. We pump with 10
+ # seconds (0.1 * 100) so the `MatrixFederationHttpClient` runs out of retries
+ # and finally passes along the error response.
+ self.pump(0.1)
+
+ # Make sure that the request was proxied through the `federation_sender` worker
+ mock_agent_on_federation_sender.request.assert_called_with(
+ b"GET",
+ b"matrix-federation://remoteserv:8008/foo/bar",
+ headers=ANY,
+ bodyProducer=ANY,
+ )
+
+ # Make sure we get some sort of error back on the main worker
+ failure_res = self.failureResultOf(test_request_from_main_process_d)
+ self.assertIsInstance(failure_res.value, RequestSendFailed)
+ self.assertIsInstance(failure_res.value.inner_exception, HttpResponseException)
+
+ @override_config({"outbound_federation_restricted_to": ["federation_sender"]})
+ def test_proxy_requests_and_discards_hop_by_hop_headers(self) -> None:
+ """
+ Test to make sure hop-by-hop headers and additional headers defined in the
+ `Connection` header are discarded when proxying requests
+ """
+ # Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance
+ # so we can act like some remote server responding to requests
+ mock_client_on_federation_sender = Mock()
+ mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True)
+ mock_client_on_federation_sender.agent = mock_agent_on_federation_sender
+
+ # Create the `federation_sender` worker
+ self.federation_sender = self.make_worker_hs(
+ "synapse.app.generic_worker",
+ {"worker_name": "federation_sender"},
+ federation_http_client=mock_client_on_federation_sender,
+ )
+
+ # Fake `remoteserv:8008` responding to requests
+ mock_agent_on_federation_sender.request.side_effect = lambda *args, **kwargs: defer.succeed(
+ FakeResponse(
+ code=200,
+ body=b'{"foo": "bar"}',
+ headers=Headers(
+ {
+ "Content-Type": ["application/json"],
+ "Connection": ["close, X-Foo, X-Bar"],
+ # Should be removed because it's defined in the `Connection` header
+ "X-Foo": ["foo"],
+ "X-Bar": ["bar"],
+ # Should be removed because it's a hop-by-hop header
+ "Proxy-Authorization": "abcdef",
+ }
+ ),
+ )
+ )
+
+ # This federation request from the main process should be proxied through the
+ # `federation_sender` worker off to the remote server
+ test_request_from_main_process_d = defer.ensureDeferred(
+ self.hs.get_federation_http_client().get_json_with_headers(
+ "remoteserv:8008", "foo/bar"
+ )
+ )
+
+ # Pump the reactor so our deferred goes through the motions
+ self.pump()
+
+ # Make sure that the request was proxied through the `federation_sender` worker
+ mock_agent_on_federation_sender.request.assert_called_once_with(
+ b"GET",
+ b"matrix-federation://remoteserv:8008/foo/bar",
+ headers=ANY,
+ bodyProducer=ANY,
+ )
+
+ res, headers = self.successResultOf(test_request_from_main_process_d)
+ header_names = set(headers.keys())
+
+ # Make sure the response does not include the hop-by-hop headers
+ self.assertNotIn(b"X-Foo", header_names)
+ self.assertNotIn(b"X-Bar", header_names)
+ self.assertNotIn(b"Proxy-Authorization", header_names)
+ # Make sure the response is as expected back on the main worker
+ self.assertEqual(res, {"foo": "bar"})
diff --git a/tests/http/test_proxy.py b/tests/http/test_proxy.py
new file mode 100644
index 0000000000..0dc9ba8e05
--- /dev/null
+++ b/tests/http/test_proxy.py
@@ -0,0 +1,53 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Set
+
+from parameterized import parameterized
+
+from synapse.http.proxy import parse_connection_header_value
+
+from tests.unittest import TestCase
+
+
+class ProxyTests(TestCase):
+ @parameterized.expand(
+ [
+ [b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}],
+ # No whitespace
+ [b"close,X-Foo,X-Bar", {"Close", "X-Foo", "X-Bar"}],
+ # More whitespace
+ [b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}],
+ # "close" directive in not the first position
+ [b"X-Foo, X-Bar, close", {"X-Foo", "X-Bar", "Close"}],
+ # Normalizes header capitalization
+ [b"keep-alive, x-fOo, x-bAr", {"Keep-Alive", "X-Foo", "X-Bar"}],
+ # Handles header names with whitespace
+ [
+ b"keep-alive, x foo, x bar",
+ {"Keep-Alive", "X foo", "X bar"},
+ ],
+ ]
+ )
+ def test_parse_connection_header_value(
+ self,
+ connection_header_value: bytes,
+ expected_extra_headers_to_remove: Set[str],
+ ) -> None:
+ """
+ Tests that the connection header value is parsed correctly
+ """
+ self.assertEqual(
+ expected_extra_headers_to_remove,
+ parse_connection_header_value(connection_header_value),
+ )
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index eb9b1f1cd9..96badc46b0 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -69,10 +69,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
- federation_http_client=None,
homeserver_to_use=GenericWorkerServer,
config=self._get_worker_hs_config(),
reactor=self.reactor,
+ federation_http_client=None,
)
# Since we use sqlite in memory databases we need to make sure the
@@ -380,6 +380,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
server_version_string="1",
max_request_body_size=8192,
reactor=self.reactor,
+ federation_agent=worker_hs.get_federation_http_client().agent,
)
worker_hs.get_replication_command_handler().start_replication(worker_hs)
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 08703206a9..a324b4d31d 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -14,14 +14,18 @@
import logging
from unittest.mock import Mock
+from netaddr import IPSet
+
from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
from synapse.handlers.typing import TypingWriterHandler
+from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client import login, room
from synapse.types import UserID, create_requester
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import get_clock
from tests.test_utils import make_awaitable
logger = logging.getLogger(__name__)
@@ -41,13 +45,25 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
room.register_servlets,
]
+ def setUp(self) -> None:
+ super().setUp()
+
+ reactor, _ = get_clock()
+ self.matrix_federation_agent = MatrixFederationAgent(
+ reactor,
+ tls_client_options_factory=None,
+ user_agent=b"SynapseInTrialTest/0.0.0",
+ ip_allowlist=None,
+ ip_blocklist=IPSet(),
+ )
+
def test_send_event_single_sender(self) -> None:
"""Test that using a single federation sender worker correctly sends a
new event.
"""
mock_client = Mock(spec=["put_json"])
mock_client.put_json.return_value = make_awaitable({})
-
+ mock_client.agent = self.matrix_federation_agent
self.make_worker_hs(
"synapse.app.generic_worker",
{
@@ -78,6 +94,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"""
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({})
+ mock_client1.agent = self.matrix_federation_agent
self.make_worker_hs(
"synapse.app.generic_worker",
{
@@ -92,6 +109,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({})
+ mock_client2.agent = self.matrix_federation_agent
self.make_worker_hs(
"synapse.app.generic_worker",
{
@@ -145,6 +163,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"""
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({})
+ mock_client1.agent = self.matrix_federation_agent
self.make_worker_hs(
"synapse.app.generic_worker",
{
@@ -159,6 +178,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({})
+ mock_client2.agent = self.matrix_federation_agent
self.make_worker_hs(
"synapse.app.generic_worker",
{
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index dcbb125a3b..e12098102b 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -40,7 +40,6 @@ class PresenceTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(
"red",
- federation_http_client=None,
federation_client=Mock(),
presence_handler=self.presence_handler,
)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index f1b4e1ad2f..d013e75d55 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -67,8 +67,6 @@ class RoomBase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver(
"red",
- federation_http_client=None,
- federation_client=Mock(),
)
self.hs.get_federation_handler = Mock() # type: ignore[assignment]
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 9cb326d90a..f6df31aba4 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -31,7 +31,7 @@ room_key: RoomKey = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver("server", federation_http_client=None)
+ hs = self.setup_test_homeserver("server")
self.store = hs.get_datastores().main
return hs
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 857e2caf2e..0282673167 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase):
servlets = [room.register_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver("server", federation_http_client=None)
+ hs = self.setup_test_homeserver("server")
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
index 6861d3a6c9..809c9f175d 100644
--- a/tests/storage/test_rollback_worker.py
+++ b/tests/storage/test_rollback_worker.py
@@ -45,9 +45,7 @@ def fake_listdir(filepath: str) -> List[str]:
class WorkerSchemaTests(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver(
- federation_http_client=None, homeserver_to_use=GenericWorkerServer
- )
+ hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
return hs
def default_config(self) -> JsonDict:
diff --git a/tests/test_server.py b/tests/test_server.py
index e266c06a2c..fe5afebdcd 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -38,7 +38,7 @@ from tests.http.server._base import test_disconnect
from tests.server import (
FakeChannel,
FakeSite,
- ThreadedMemoryReactorClock,
+ get_clock,
make_request,
setup_test_homeserver,
)
@@ -46,12 +46,11 @@ from tests.server import (
class JsonResourceTests(unittest.TestCase):
def setUp(self) -> None:
- self.reactor = ThreadedMemoryReactorClock()
- self.hs_clock = Clock(self.reactor)
+ reactor, clock = get_clock()
+ self.reactor = reactor
self.homeserver = setup_test_homeserver(
self.addCleanup,
- federation_http_client=None,
- clock=self.hs_clock,
+ clock=clock,
reactor=self.reactor,
)
@@ -209,7 +208,13 @@ class JsonResourceTests(unittest.TestCase):
class OptionsResourceTests(unittest.TestCase):
def setUp(self) -> None:
- self.reactor = ThreadedMemoryReactorClock()
+ reactor, clock = get_clock()
+ self.reactor = reactor
+ self.homeserver = setup_test_homeserver(
+ self.addCleanup,
+ clock=clock,
+ reactor=self.reactor,
+ )
class DummyResource(Resource):
isLeaf = True
@@ -242,6 +247,7 @@ class OptionsResourceTests(unittest.TestCase):
"1.0",
max_request_body_size=4096,
reactor=self.reactor,
+ federation_agent=self.homeserver.get_federation_http_client().agent,
)
# render the request and return the channel
@@ -344,7 +350,8 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
await self.callback(request)
def setUp(self) -> None:
- self.reactor = ThreadedMemoryReactorClock()
+ reactor, _ = get_clock()
+ self.reactor = reactor
def test_good_response(self) -> None:
async def callback(request: SynapseRequest) -> None:
@@ -462,9 +469,9 @@ class DirectServeJsonResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeJsonResource` cancellation."""
def setUp(self) -> None:
- self.reactor = ThreadedMemoryReactorClock()
- self.clock = Clock(self.reactor)
- self.resource = CancellableDirectServeJsonResource(self.clock)
+ reactor, clock = get_clock()
+ self.reactor = reactor
+ self.resource = CancellableDirectServeJsonResource(clock)
self.site = FakeSite(self.resource, self.reactor)
def test_cancellable_disconnect(self) -> None:
@@ -496,9 +503,9 @@ class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeHtmlResource` cancellation."""
def setUp(self) -> None:
- self.reactor = ThreadedMemoryReactorClock()
- self.clock = Clock(self.reactor)
- self.resource = CancellableDirectServeHtmlResource(self.clock)
+ reactor, clock = get_clock()
+ self.reactor = reactor
+ self.resource = CancellableDirectServeHtmlResource(clock)
self.site = FakeSite(self.resource, self.reactor)
def test_cancellable_disconnect(self) -> None:
diff --git a/tests/unittest.py b/tests/unittest.py
index c73195b32b..334a95a917 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -358,6 +358,7 @@ class HomeserverTestCase(TestCase):
server_version_string="1",
max_request_body_size=4096,
reactor=self.reactor,
+ federation_agent=self.hs.get_federation_http_client().agent,
)
from tests.rest.client.utils import RestHelper
|