summary refs log tree commit diff
path: root/synapse/http/federation/matrix_federation_agent.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/federation/matrix_federation_agent.py')
-rw-r--r--synapse/http/federation/matrix_federation_agent.py114
1 files changed, 110 insertions, 4 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 4a6f634c8b..07c72c9351 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import cgi
+import json
 import logging
 
 import attr
@@ -20,7 +22,7 @@ from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.web.client import URI, Agent, HTTPConnectionPool
+from twisted.web.client import URI, Agent, HTTPConnectionPool, readBody
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent
 
@@ -43,13 +45,19 @@ class MatrixFederationAgent(object):
         tls_client_options_factory (ClientTLSOptionsFactory|None):
             factory to use for fetching client tls options, or none to disable TLS.
 
+        _well_known_tls_policy (IPolicyForHTTPS|None):
+            TLS policy to use for fetching .well-known files. None to use a default
+            (browser-like) implementation.
+
         srv_resolver (SrvResolver|None):
             SRVResolver impl to use for looking up SRV records. None to use a default
             implementation.
     """
 
     def __init__(
-        self, reactor, tls_client_options_factory, _srv_resolver=None,
+        self, reactor, tls_client_options_factory,
+        _well_known_tls_policy=None,
+        _srv_resolver=None,
     ):
         self._reactor = reactor
         self._tls_client_options_factory = tls_client_options_factory
@@ -62,6 +70,14 @@ class MatrixFederationAgent(object):
         self._pool.maxPersistentPerHost = 5
         self._pool.cachedConnectionTimeout = 2 * 60
 
+        agent_args = {}
+        if _well_known_tls_policy is not None:
+            # the param is called 'contextFactory', but actually passing a
+            # contextfactory is deprecated, and it expects an IPolicyForHTTPS.
+            agent_args['contextFactory'] = _well_known_tls_policy
+        _well_known_agent = Agent(self._reactor, pool=self._pool, **agent_args)
+        self._well_known_agent = _well_known_agent
+
     @defer.inlineCallbacks
     def request(self, method, uri, headers=None, bodyProducer=None):
         """
@@ -114,7 +130,11 @@ class MatrixFederationAgent(object):
         class EndpointFactory(object):
             @staticmethod
             def endpointForURI(_uri):
-                logger.info("Connecting to %s:%s", res.target_host, res.target_port)
+                logger.info(
+                    "Connecting to %s:%i",
+                    res.target_host.decode("ascii"),
+                    res.target_port,
+                )
                 ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port)
                 if tls_options is not None:
                     ep = wrapClientTLS(tls_options, ep)
@@ -127,7 +147,7 @@ class MatrixFederationAgent(object):
         defer.returnValue(res)
 
     @defer.inlineCallbacks
-    def _route_matrix_uri(self, parsed_uri):
+    def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
         """Helper for `request`: determine the routing for a Matrix URI
 
         Args:
@@ -135,6 +155,9 @@ class MatrixFederationAgent(object):
                 parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
                 if there is no explicit port given.
 
+            lookup_well_known (bool): True if we should look up the .well-known file if
+                there is no SRV record.
+
         Returns:
             Deferred[_RoutingResult]
         """
@@ -169,6 +192,42 @@ class MatrixFederationAgent(object):
         service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
         server_list = yield self._srv_resolver.resolve_service(service_name)
 
+        if not server_list and lookup_well_known:
+            # try a .well-known lookup
+            well_known_server = yield self._get_well_known(parsed_uri.host)
+
+            if well_known_server:
+                # if we found a .well-known, start again, but don't do another
+                # .well-known lookup.
+
+                # parse the server name in the .well-known response into host/port.
+                # (This code is lifted from twisted.web.client.URI.fromBytes).
+                if b':' in well_known_server:
+                    well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
+                    try:
+                        well_known_port = int(well_known_port)
+                    except ValueError:
+                        # the part after the colon could not be parsed as an int
+                        # - we assume it is an IPv6 literal with no port (the closing
+                        # ']' stops it being parsed as an int)
+                        well_known_host, well_known_port = well_known_server, -1
+                else:
+                    well_known_host, well_known_port = well_known_server, -1
+
+                new_uri = URI(
+                    scheme=parsed_uri.scheme,
+                    netloc=well_known_server,
+                    host=well_known_host,
+                    port=well_known_port,
+                    path=parsed_uri.path,
+                    params=parsed_uri.params,
+                    query=parsed_uri.query,
+                    fragment=parsed_uri.fragment,
+                )
+
+                res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
+                defer.returnValue(res)
+
         if not server_list:
             target_host = parsed_uri.host
             port = 8448
@@ -190,6 +249,53 @@ class MatrixFederationAgent(object):
             target_port=port,
         ))
 
+    @defer.inlineCallbacks
+    def _get_well_known(self, server_name):
+        """Attempt to fetch and parse a .well-known file for the given server
+
+        Args:
+            server_name (bytes): name of the server, from the requested url
+
+        Returns:
+            Deferred[bytes|None]: either the new server name, from the .well-known, or
+                None if there was no .well-known file.
+        """
+        # FIXME: add a cache
+
+        uri = b"https://%s/.well-known/matrix/server" % (server_name, )
+        logger.info("Fetching %s", uri.decode("ascii"))
+        try:
+            response = yield make_deferred_yieldable(
+                self._well_known_agent.request(b"GET", uri),
+            )
+        except Exception as e:
+            logger.info(
+                "Connection error fetching %s: %s",
+                uri.decode("ascii"), e,
+            )
+            defer.returnValue(None)
+
+        body = yield make_deferred_yieldable(readBody(response))
+
+        if response.code != 200:
+            logger.info(
+                "Error response %i from %s: %s",
+                response.code, uri.decode("ascii"), body,
+            )
+            defer.returnValue(None)
+
+        content_types = response.headers.getRawHeaders(u'content-type')
+        if content_types is None:
+            raise Exception("no content-type header on .well-known response")
+        content_type, _opts = cgi.parse_header(content_types[-1])
+        if content_type != 'application/json':
+            raise Exception("content-type not application/json on .well-known response")
+        parsed_body = json.loads(body.decode('utf-8'))
+        logger.info("Response from .well-known: %s", parsed_body)
+        if not isinstance(parsed_body, dict) or "m.server" not in parsed_body:
+            raise Exception("invalid .well-known response")
+        defer.returnValue(parsed_body["m.server"].encode("ascii"))
+
 
 @attr.s
 class _RoutingResult(object):