summary refs log tree commit diff
path: root/synapse/http/federation/matrix_federation_agent.py
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2019-01-28 14:08:24 +0000
committerAndrew Morgan <andrew@amorgan.xyz>2019-01-28 14:08:24 +0000
commit4026d555fabe6f00d865a0c63226799e53a092e1 (patch)
tree9a8211924748e48915e1acf6d097e0bff88a03bb /synapse/http/federation/matrix_federation_agent.py
parentReuse predecessor method (diff)
parentRemove --process-dependency-links from UPGRADE.rst (#4485) (diff)
downloadsynapse-4026d555fabe6f00d865a0c63226799e53a092e1.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into anoa/dm_room_upgrade
Diffstat (limited to 'synapse/http/federation/matrix_federation_agent.py')
-rw-r--r--synapse/http/federation/matrix_federation_agent.py144
1 files changed, 127 insertions, 17 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 0ec28c6696..4a6f634c8b 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,14 +14,16 @@
 # limitations under the License.
 import logging
 
+import attr
+from netaddr import IPAddress
 from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.web.client import URI, Agent, HTTPConnectionPool
+from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent
 
-from synapse.http.endpoint import parse_server_name
 from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
 from synapse.util.logcontext import make_deferred_yieldable
 
@@ -85,35 +87,35 @@ class MatrixFederationAgent(object):
                 response from being received (including problems that prevent the request
                 from being sent).
         """
+        parsed_uri = URI.fromBytes(uri, defaultPort=-1)
+        res = yield self._route_matrix_uri(parsed_uri)
 
-        parsed_uri = URI.fromBytes(uri)
-        server_name_bytes = parsed_uri.netloc
-        host, port = parse_server_name(server_name_bytes.decode("ascii"))
-
+        # set up the TLS connection params
+        #
         # XXX disabling TLS is really only supported here for the benefit of the
         # unit tests. We should make the UTs cope with TLS rather than having to make
         # the code support the unit tests.
         if self._tls_client_options_factory is None:
             tls_options = None
         else:
-            tls_options = self._tls_client_options_factory.get_options(host)
+            tls_options = self._tls_client_options_factory.get_options(
+                res.tls_server_name.decode("ascii")
+            )
 
-        if port is not None:
-            target = (host, port)
+        # make sure that the Host header is set correctly
+        if headers is None:
+            headers = Headers()
         else:
-            service_name = b"_matrix._tcp.%s" % (server_name_bytes, )
-            server_list = yield self._srv_resolver.resolve_service(service_name)
-            if not server_list:
-                target = (host, 8448)
-                logger.debug("No SRV record for %s, using %s", host, target)
-            else:
-                target = pick_server_from_list(server_list)
+            headers = headers.copy()
+
+        if not headers.hasHeader(b'host'):
+            headers.addRawHeader(b'host', res.host_header)
 
         class EndpointFactory(object):
             @staticmethod
             def endpointForURI(_uri):
-                logger.info("Connecting to %s:%s", target[0], target[1])
-                ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
+                logger.info("Connecting to %s:%s", res.target_host, res.target_port)
+                ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port)
                 if tls_options is not None:
                     ep = wrapClientTLS(tls_options, ep)
                 return ep
@@ -123,3 +125,111 @@ class MatrixFederationAgent(object):
             agent.request(method, uri, headers, bodyProducer)
         )
         defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def _route_matrix_uri(self, parsed_uri):
+        """Helper for `request`: determine the routing for a Matrix URI
+
+        Args:
+            parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
+                parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
+                if there is no explicit port given.
+
+        Returns:
+            Deferred[_RoutingResult]
+        """
+        # check for an IP literal
+        try:
+            ip_address = IPAddress(parsed_uri.host.decode("ascii"))
+        except Exception:
+            # not an IP address
+            ip_address = None
+
+        if ip_address:
+            port = parsed_uri.port
+            if port == -1:
+                port = 8448
+            defer.returnValue(_RoutingResult(
+                host_header=parsed_uri.netloc,
+                tls_server_name=parsed_uri.host,
+                target_host=parsed_uri.host,
+                target_port=port,
+            ))
+
+        if parsed_uri.port != -1:
+            # there is an explicit port
+            defer.returnValue(_RoutingResult(
+                host_header=parsed_uri.netloc,
+                tls_server_name=parsed_uri.host,
+                target_host=parsed_uri.host,
+                target_port=parsed_uri.port,
+            ))
+
+        # try a SRV lookup
+        service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
+        server_list = yield self._srv_resolver.resolve_service(service_name)
+
+        if not server_list:
+            target_host = parsed_uri.host
+            port = 8448
+            logger.debug(
+                "No SRV record for %s, using %s:%i",
+                parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
+            )
+        else:
+            target_host, port = pick_server_from_list(server_list)
+            logger.debug(
+                "Picked %s:%i from SRV records for %s",
+                target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
+            )
+
+        defer.returnValue(_RoutingResult(
+            host_header=parsed_uri.netloc,
+            tls_server_name=parsed_uri.host,
+            target_host=target_host,
+            target_port=port,
+        ))
+
+
+@attr.s
+class _RoutingResult(object):
+    """The result returned by `_route_matrix_uri`.
+
+    Contains the parameters needed to direct a federation connection to a particular
+    server.
+
+    Where a SRV record points to several servers, this object contains a single server
+    chosen from the list.
+    """
+
+    host_header = attr.ib()
+    """
+    The value we should assign to the Host header (host:port from the matrix
+    URI, or .well-known).
+
+    :type: bytes
+    """
+
+    tls_server_name = attr.ib()
+    """
+    The server name we should set in the SNI (typically host, without port, from the
+    matrix URI or .well-known)
+
+    :type: bytes
+    """
+
+    target_host = attr.ib()
+    """
+    The hostname (or IP literal) we should route the TCP connection to (the target of the
+    SRV record, or the hostname from the URL/.well-known)
+
+    :type: bytes
+    """
+
+    target_port = attr.ib()
+    """
+    The port we should route the TCP connection to (the target of the SRV record, or
+    the port from the URL/.well-known, or 8448)
+
+    :type: int
+    """