diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 71a15f434d..c208185791 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,21 +14,21 @@
# limitations under the License.
import logging
+import urllib
-import attr
-from netaddr import IPAddress
+from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import URI, Agent, HTTPConnectionPool
+from twisted.web.client import Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent
+from twisted.web.iweb import IAgent, IAgentEndpointFactory
-from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -36,8 +36,9 @@ logger = logging.getLogger(__name__)
@implementer(IAgent)
class MatrixFederationAgent(object):
- """An Agent-like thing which provides a `request` method which will look up a matrix
- server and send an HTTP request to it.
+ """An Agent-like thing which provides a `request` method which correctly
+ handles resolving matrix server names when using matrix://. Handles standard
+ https URIs as normal.
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
@@ -65,23 +66,25 @@ class MatrixFederationAgent(object):
):
self._reactor = reactor
self._clock = Clock(reactor)
-
- self._tls_client_options_factory = tls_client_options_factory
- if _srv_resolver is None:
- _srv_resolver = SrvResolver()
- self._srv_resolver = _srv_resolver
-
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60
+ self._agent = Agent.usingEndpointFactory(
+ self._reactor,
+ MatrixHostnameEndpointFactory(
+ reactor, tls_client_options_factory, _srv_resolver
+ ),
+ pool=self._pool,
+ )
+
self._well_known_resolver = WellKnownResolver(
self._reactor,
agent=Agent(
self._reactor,
- pool=self._pool,
contextFactory=tls_client_options_factory,
+ pool=self._pool,
),
well_known_cache=_well_known_cache,
)
@@ -91,19 +94,15 @@ class MatrixFederationAgent(object):
"""
Args:
method (bytes): HTTP method: GET/POST/etc
-
uri (bytes): Absolute URI to be retrieved
-
headers (twisted.web.http_headers.Headers|None):
HTTP headers to send with the request, or None to
send no extra headers.
-
bodyProducer (twisted.web.iweb.IBodyProducer|None):
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
no body.
-
Returns:
Deferred[twisted.web.iweb.IResponse]:
fires when the header of the response has been received (regardless of the
@@ -111,210 +110,195 @@ class MatrixFederationAgent(object):
response from being received (including problems that prevent the request
from being sent).
"""
- parsed_uri = URI.fromBytes(uri, defaultPort=-1)
- res = yield self._route_matrix_uri(parsed_uri)
+ # We use urlparse as that will set `port` to None if there is no
+ # explicit port.
+ parsed_uri = urllib.parse.urlparse(uri)
- # set up the TLS connection params
+ # If this is a matrix:// URI check if the server has delegated matrix
+ # traffic using well-known delegation.
#
- # XXX disabling TLS is really only supported here for the benefit of the
- # unit tests. We should make the UTs cope with TLS rather than having to make
- # the code support the unit tests.
- if self._tls_client_options_factory is None:
- tls_options = None
- else:
- tls_options = self._tls_client_options_factory.get_options(
- res.tls_server_name.decode("ascii")
+ # We have to do this here and not in the endpoint as we need to rewrite
+ # the host header with the delegated server name.
+ delegated_server = None
+ if (
+ parsed_uri.scheme == b"matrix"
+ and not _is_ip_literal(parsed_uri.hostname)
+ and not parsed_uri.port
+ ):
+ well_known_result = yield self._well_known_resolver.get_well_known(
+ parsed_uri.hostname
+ )
+ delegated_server = well_known_result.delegated_server
+
+ if delegated_server:
+ # Ok, the server has delegated matrix traffic to somewhere else, so
+ # lets rewrite the URL to replace the server with the delegated
+ # server name.
+ uri = urllib.parse.urlunparse(
+ (
+ parsed_uri.scheme,
+ delegated_server,
+ parsed_uri.path,
+ parsed_uri.params,
+ parsed_uri.query,
+ parsed_uri.fragment,
+ )
)
+ parsed_uri = urllib.parse.urlparse(uri)
- # make sure that the Host header is set correctly
+ # We need to make sure the host header is set to the netloc of the
+ # server.
if headers is None:
headers = Headers()
else:
headers = headers.copy()
if not headers.hasHeader(b"host"):
- headers.addRawHeader(b"host", res.host_header)
+ headers.addRawHeader(b"host", parsed_uri.netloc)
- class EndpointFactory(object):
- @staticmethod
- def endpointForURI(_uri):
- ep = LoggingHostnameEndpoint(
- self._reactor, res.target_host, res.target_port
- )
- if tls_options is not None:
- ep = wrapClientTLS(tls_options, ep)
- return ep
+ with PreserveLoggingContext():
+ res = yield self._agent.request(method, uri, headers, bodyProducer)
- agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
- res = yield make_deferred_yieldable(
- agent.request(method, uri, headers, bodyProducer)
- )
return res
- @defer.inlineCallbacks
- def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
- """Helper for `request`: determine the routing for a Matrix URI
- Args:
- parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
- parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
- if there is no explicit port given.
+@implementer(IAgentEndpointFactory)
+class MatrixHostnameEndpointFactory(object):
+ """Factory for MatrixHostnameEndpoint for passing to an Agent.
+ """
- lookup_well_known (bool): True if we should look up the .well-known file if
- there is no SRV record.
+ def __init__(self, reactor, tls_client_options_factory, srv_resolver):
+ self._reactor = reactor
+ self._tls_client_options_factory = tls_client_options_factory
- Returns:
- Deferred[_RoutingResult]
- """
- # check for an IP literal
- try:
- ip_address = IPAddress(parsed_uri.host.decode("ascii"))
- except Exception:
- # not an IP address
- ip_address = None
-
- if ip_address:
- port = parsed_uri.port
- if port == -1:
- port = 8448
- return _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=port,
- )
+ if srv_resolver is None:
+ srv_resolver = SrvResolver()
- if parsed_uri.port != -1:
- # there is an explicit port
- return _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=parsed_uri.port,
- )
+ self._srv_resolver = srv_resolver
- if lookup_well_known:
- # try a .well-known lookup
- well_known_result = yield self._well_known_resolver.get_well_known(
- parsed_uri.host
- )
- well_known_server = well_known_result.delegated_server
-
- if well_known_server:
- # if we found a .well-known, start again, but don't do another
- # .well-known lookup.
-
- # parse the server name in the .well-known response into host/port.
- # (This code is lifted from twisted.web.client.URI.fromBytes).
- if b":" in well_known_server:
- well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
- try:
- well_known_port = int(well_known_port)
- except ValueError:
- # the part after the colon could not be parsed as an int
- # - we assume it is an IPv6 literal with no port (the closing
- # ']' stops it being parsed as an int)
- well_known_host, well_known_port = well_known_server, -1
- else:
- well_known_host, well_known_port = well_known_server, -1
-
- new_uri = URI(
- scheme=parsed_uri.scheme,
- netloc=well_known_server,
- host=well_known_host,
- port=well_known_port,
- path=parsed_uri.path,
- params=parsed_uri.params,
- query=parsed_uri.query,
- fragment=parsed_uri.fragment,
- )
+ def endpointForURI(self, parsed_uri):
+ return MatrixHostnameEndpoint(
+ self._reactor,
+ self._tls_client_options_factory,
+ self._srv_resolver,
+ parsed_uri,
+ )
- res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
- return res
-
- # try a SRV lookup
- service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
- server_list = yield self._srv_resolver.resolve_service(service_name)
-
- if not server_list:
- target_host = parsed_uri.host
- port = 8448
- logger.debug(
- "No SRV record for %s, using %s:%i",
- parsed_uri.host.decode("ascii"),
- target_host.decode("ascii"),
- port,
- )
+
+@implementer(IStreamClientEndpoint)
+class MatrixHostnameEndpoint(object):
+ """An endpoint that resolves matrix:// URLs using Matrix server name
+ resolution (i.e. via SRV). Does not check for well-known delegation.
+ """
+
+ def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
+ self._reactor = reactor
+
+ # Store the URI as given; `_resolve_server` reparses it so that a missing port is None
+ self._parsed_uri = parsed_uri
+
+ # set up the TLS connection params
+ #
+ # XXX disabling TLS is really only supported here for the benefit of the
+ # unit tests. We should make the UTs cope with TLS rather than having to make
+ # the code support the unit tests.
+
+ if tls_client_options_factory is None:
+ self._tls_options = None
else:
- target_host, port = pick_server_from_list(server_list)
- logger.debug(
- "Picked %s:%i from SRV records for %s",
- target_host.decode("ascii"),
- port,
- parsed_uri.host.decode("ascii"),
+ self._tls_options = tls_client_options_factory.get_options(
+ self._parsed_uri.host.decode("ascii")
)
- return _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=target_host,
- target_port=port,
- )
+ self._srv_resolver = srv_resolver
+ @defer.inlineCallbacks
+ def connect(self, protocol_factory):
+ """Implements IStreamClientEndpoint interface
+ """
-@implementer(IStreamClientEndpoint)
-class LoggingHostnameEndpoint(object):
- """A wrapper for HostnameEndpint which logs when it connects"""
+ first_exception = None
- def __init__(self, reactor, host, port, *args, **kwargs):
- self.host = host
- self.port = port
- self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
+ server_list = yield self._resolve_server()
- def connect(self, protocol_factory):
- logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
- return self.ep.connect(protocol_factory)
+ for server in server_list:
+ host = server.host
+ port = server.port
+ try:
+ logger.info("Connecting to %s:%i", host.decode("ascii"), port)
+ endpoint = HostnameEndpoint(self._reactor, host, port)
+ if self._tls_options:
+ endpoint = wrapClientTLS(self._tls_options, endpoint)
+ result = yield make_deferred_yieldable(
+ endpoint.connect(protocol_factory)
+ )
-@attr.s
-class _RoutingResult(object):
- """The result returned by `_route_matrix_uri`.
+ return result
+ except Exception as e:
+ logger.info(
+ "Failed to connect to %s:%i: %s", host.decode("ascii"), port, e
+ )
+ if not first_exception:
+ first_exception = e
- Contains the parameters needed to direct a federation connection to a particular
- server.
+ # We return the first failure because that's probably the most interesting.
+ if first_exception:
+ raise first_exception
- Where a SRV record points to several servers, this object contains a single server
- chosen from the list.
- """
+ # This shouldn't happen as we should always have at least one host/port
+ # to try and if that doesn't work then we'll have an exception.
+ raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
- host_header = attr.ib()
- """
- The value we should assign to the Host header (host:port from the matrix
- URI, or .well-known).
+ @defer.inlineCallbacks
+ def _resolve_server(self):
+ """Resolves the server name to a list of hosts and ports to attempt to
+ connect to.
- :type: bytes
- """
+ Returns:
+ Deferred[list[Server]]
+ """
- tls_server_name = attr.ib()
- """
- The server name we should set in the SNI (typically host, without port, from the
- matrix URI or .well-known)
+ if self._parsed_uri.scheme != b"matrix":
+ return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)]
- :type: bytes
- """
+ # Note: We don't do well-known lookup as that needs to have happened
+ # before now, due to needing to rewrite the Host header of the HTTP
+ # request.
- target_host = attr.ib()
- """
- The hostname (or IP literal) we should route the TCP connection to (the target of the
- SRV record, or the hostname from the URL/.well-known)
+ parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes())
- :type: bytes
- """
+ host = parsed_uri.hostname
+ port = parsed_uri.port
- target_port = attr.ib()
- """
- The port we should route the TCP connection to (the target of the SRV record, or
- the port from the URL/.well-known, or 8448)
+ # If there is an explicit port or the host is an IP address we bypass
+ # SRV lookups and just use the given host/port.
+ if port or _is_ip_literal(host):
+ return [Server(host, port or 8448)]
- :type: int
+ server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
+
+ if server_list:
+ return server_list
+
+ # No SRV records, so we fallback to host and 8448
+ return [Server(host, 8448)]
+
+
+def _is_ip_literal(host):
+ """Test if the given host name is either an IPv4 or IPv6 literal.
+
+ Args:
+ host (bytes)
+
+ Returns:
+ bool
"""
+
+ host = host.decode("ascii")
+
+ try:
+ IPAddress(host)
+ return True
+ except AddrFormatError:
+ return False
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b32188766d..bbda0a23f4 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
SERVER_CACHE = {}
-@attr.s
+@attr.s(slots=True, frozen=True)
class Server(object):
"""
Our record of an individual server which can be tried to reach a destination.
@@ -83,6 +83,35 @@ def pick_server_from_list(server_list):
raise RuntimeError("pick_server_from_list got to end of eligible server list.")
+def _sort_server_list(server_list):
+    """Order SRV records for connection attempts, roughly as per RFC 2782.
+
+    Records are grouped by ``priority`` (lowest value first). Within each
+    group the order is a weighted random shuffle: records are drawn one at a
+    time, each with probability proportional to its ``weight`` among the
+    records not yet drawn. Zero-weight records are only placed once no
+    weighted record is left in their group.
+
+    Args:
+        server_list (list[Server]): records to order; only the ``priority``
+            and ``weight`` attributes are read. The list is not modified.
+
+    Returns:
+        list[Server]: a new list containing every input record exactly once.
+    """
+    priority_map = {}
+
+    for server in server_list:
+        priority_map.setdefault(server.priority, []).append(server)
+
+    results = []
+    for priority in sorted(priority_map):
+        servers = priority_map[priority]
+        total_weight = sum(s.weight for s in servers)
+
+        # Weights can be zero. Stop the weighted draw once only zero-weight
+        # records are left; they are shuffled uniformly below.
+        while servers and total_weight:
+            # Draw from 1..total_weight (randint is inclusive at both ends).
+            # Starting at 0 would let the first record win even with weight
+            # 0, biasing the result towards whatever is first in the list.
+            target_weight = random.randint(1, total_weight)
+
+            for s in servers:
+                target_weight -= s.weight
+
+                if target_weight <= 0:
+                    break
+
+            results.append(s)
+            servers.remove(s)
+            total_weight -= s.weight
+
+        if servers:
+            random.shuffle(servers)
+            results.extend(servers)
+
+    return results
+
+
class SrvResolver(object):
"""Interface to the dns client to do SRV lookups, with result caching.
@@ -120,7 +149,7 @@ class SrvResolver(object):
if cache_entry:
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
- return servers
+ return _sort_server_list(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
@@ -169,4 +198,4 @@ class SrvResolver(object):
)
self._cache[service_name] = list(servers)
- return servers
+ return _sort_server_list(servers)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 2c568788b3..f97c8a59f6 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -41,9 +41,9 @@ from synapse.http.federation.well_known_resolver import (
from synapse.logging.context import LoggingContext
from synapse.util.caches.ttlcache import TTLCache
+from tests import unittest
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.server import FakeTransport, ThreadedMemoryReactorClock
-from tests.unittest import TestCase
from tests.utils import default_config
logger = logging.getLogger(__name__)
@@ -67,7 +67,8 @@ def get_connection_factory():
return test_server_connection_factory
-class MatrixFederationAgentTests(TestCase):
+@unittest.DEBUG
+class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
@@ -1056,8 +1057,64 @@ class MatrixFederationAgentTests(TestCase):
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, None)
+ def test_srv_fallbacks(self):
+ """Test that other SRV results are tried if the first one fails.
+ """
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: [
+ Server(host=b"target.com", port=8443),
+ Server(host=b"target.com", port=8444),
+ ]
+ self.reactor.lookups["target.com"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.testserv"
+ )
+
+ # We should see an attempt to connect to the first server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # Fail the connection attempt
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # There's a 300ms delay in HostnameEndpoint
+ self.reactor.pump((0.4,))
+
+ # Hasn't failed yet
+ self.assertNoResult(test_d)
+
+ # We should now see an attempt to connect to the second server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8444)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(client_factory, expected_sni=b"testserv")
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
-class TestCachePeriodFromHeaders(TestCase):
+class TestCachePeriodFromHeaders(unittest.TestCase):
def test_cache_control(self):
# uppercase
self.assertEqual(
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 3b885ef64b..df034ab237 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -83,8 +83,10 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
- entry = Mock(spec_set=["expires"])
+ entry = Mock(spec_set=["expires", "priority", "weight"])
entry.expires = 0
+ entry.priority = 0
+ entry.weight = 0
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@@ -105,8 +107,10 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
- entry = Mock(spec_set=["expires"])
+ entry = Mock(spec_set=["expires", "priority", "weight"])
entry.expires = 999999999
+ entry.priority = 0
+ entry.weight = 0
cache = {service_name: [entry]}
resolver = SrvResolver(
|