diff --git a/changelog.d/4488.feature b/changelog.d/4488.feature
new file mode 100644
index 0000000000..bda713adf9
--- /dev/null
+++ b/changelog.d/4488.feature
@@ -0,0 +1 @@
+Implement MSC1708 (.well-known routing for server-server federation)
\ No newline at end of file
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 9526f39cca..4a6f634c8b 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,6 +14,8 @@
# limitations under the License.
import logging
+import attr
+from netaddr import IPAddress
from zope.interface import implementer
from twisted.internet import defer
@@ -85,9 +87,11 @@ class MatrixFederationAgent(object):
response from being received (including problems that prevent the request
from being sent).
"""
-
parsed_uri = URI.fromBytes(uri, defaultPort=-1)
+ res = yield self._route_matrix_uri(parsed_uri)
+ # set up the TLS connection params
+ #
# XXX disabling TLS is really only supported here for the benefit of the
# unit tests. We should make the UTs cope with TLS rather than having to make
# the code support the unit tests.
@@ -95,22 +99,9 @@ class MatrixFederationAgent(object):
tls_options = None
else:
tls_options = self._tls_client_options_factory.get_options(
- parsed_uri.host.decode("ascii")
+ res.tls_server_name.decode("ascii")
)
- if parsed_uri.port != -1:
- # there was an explicit port in the URI
- target = parsed_uri.host, parsed_uri.port
- else:
- service_name = b"_matrix._tcp.%s" % (parsed_uri.host, )
- server_list = yield self._srv_resolver.resolve_service(service_name)
- if not server_list:
- target = (parsed_uri.host, 8448)
- logger.debug(
- "No SRV record for %s, using %s", service_name, target)
- else:
- target = pick_server_from_list(server_list)
-
# make sure that the Host header is set correctly
if headers is None:
headers = Headers()
@@ -118,13 +109,13 @@ class MatrixFederationAgent(object):
headers = headers.copy()
if not headers.hasHeader(b'host'):
- headers.addRawHeader(b'host', parsed_uri.netloc)
+ headers.addRawHeader(b'host', res.host_header)
class EndpointFactory(object):
@staticmethod
def endpointForURI(_uri):
- logger.info("Connecting to %s:%s", target[0], target[1])
- ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
+ logger.info("Connecting to %s:%s", res.target_host, res.target_port)
+ ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port)
if tls_options is not None:
ep = wrapClientTLS(tls_options, ep)
return ep
@@ -134,3 +125,111 @@ class MatrixFederationAgent(object):
agent.request(method, uri, headers, bodyProducer)
)
defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def _route_matrix_uri(self, parsed_uri):
+ """Helper for `request`: determine the routing for a Matrix URI
+
+ Args:
+ parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
+ parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
+ if there is no explicit port given.
+
+ Returns:
+ Deferred[_RoutingResult]
+ """
+ # check for an IP literal
+ try:
+ ip_address = IPAddress(parsed_uri.host.decode("ascii"))
+ except Exception:
+ # not an IP address
+ ip_address = None
+
+ if ip_address:
+ port = parsed_uri.port
+ if port == -1:
+ port = 8448
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=port,
+ ))
+
+ if parsed_uri.port != -1:
+ # there is an explicit port
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=parsed_uri.port,
+ ))
+
+ # try a SRV lookup
+ service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
+ server_list = yield self._srv_resolver.resolve_service(service_name)
+
+ if not server_list:
+ target_host = parsed_uri.host
+ port = 8448
+ logger.debug(
+ "No SRV record for %s, using %s:%i",
+ parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
+ )
+ else:
+ target_host, port = pick_server_from_list(server_list)
+ logger.debug(
+ "Picked %s:%i from SRV records for %s",
+ target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
+ )
+
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=target_host,
+ target_port=port,
+ ))
+
+
+@attr.s
+class _RoutingResult(object):
+ """The result returned by `_route_matrix_uri`.
+
+ Contains the parameters needed to direct a federation connection to a particular
+ server.
+
+ Where a SRV record points to several servers, this object contains a single server
+ chosen from the list.
+ """
+
+ host_header = attr.ib()
+ """
+ The value we should assign to the Host header (host:port from the matrix
+ URI, or .well-known).
+
+ :type: bytes
+ """
+
+ tls_server_name = attr.ib()
+ """
+ The server name we should set in the SNI (typically host, without port, from the
+ matrix URI or .well-known)
+
+ :type: bytes
+ """
+
+ target_host = attr.ib()
+ """
+ The hostname (or IP literal) we should route the TCP connection to (the target of the
+ SRV record, or the hostname from the URL/.well-known)
+
+ :type: bytes
+ """
+
+ target_port = attr.ib()
+ """
+ The port we should route the TCP connection to (the target of the SRV record, or
+ the port from the URL/.well-known, or 8448)
+
+ :type: int
+ """
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index f144092a51..8257594fb8 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -166,11 +166,7 @@ class MatrixFederationAgentTests(TestCase):
"""
Test the behaviour when the server name contains an explicit IP (with no port)
"""
-
- # the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
- self.mock_resolver.resolve_service.side_effect = lambda _: []
-
- # then there will be a getaddrinfo on the IP
+ # there will be a getaddrinfo on the IP
self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
@@ -178,10 +174,6 @@ class MatrixFederationAgentTests(TestCase):
# Nothing happened yet
self.assertNoResult(test_d)
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.1.2.3.4",
- )
-
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
@@ -215,10 +207,7 @@ class MatrixFederationAgentTests(TestCase):
(with no port)
"""
- # the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
- self.mock_resolver.resolve_service.side_effect = lambda _: []
-
- # then there will be a getaddrinfo on the IP
+ # there will be a getaddrinfo on the IP
self.reactor.lookups["::1"] = "::1"
test_d = self._make_get_request(b"matrix://[::1]/foo/bar")
@@ -226,10 +215,6 @@ class MatrixFederationAgentTests(TestCase):
# Nothing happened yet
self.assertNoResult(test_d)
- self.mock_resolver.resolve_service.assert_called_once_with(
- b"_matrix._tcp.::1",
- )
-
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
|