summary refs log tree commit diff
path: root/synapse/http/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/federation')
-rw-r--r--synapse/http/federation/matrix_federation_agent.py386
-rw-r--r--synapse/http/federation/srv_resolver.py61
-rw-r--r--synapse/http/federation/well_known_resolver.py188
3 files changed, 380 insertions, 255 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 71a15f434d..647d26dc56 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,21 +14,21 @@
 # limitations under the License.
 
 import logging
+import urllib
 
-import attr
-from netaddr import IPAddress
+from netaddr import AddrFormatError, IPAddress
 from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import URI, Agent, HTTPConnectionPool
+from twisted.web.client import Agent, HTTPConnectionPool
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent
+from twisted.web.iweb import IAgent, IAgentEndpointFactory
 
-from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import WellKnownResolver
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -36,8 +36,9 @@ logger = logging.getLogger(__name__)
 
 @implementer(IAgent)
 class MatrixFederationAgent(object):
-    """An Agent-like thing which provides a `request` method which will look up a matrix
-    server and send an HTTP request to it.
+    """An Agent-like thing which provides a `request` method which correctly
+    handles resolving matrix server names when using matrix://. Handles standard
+    https URIs as normal.
 
     Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
 
@@ -51,9 +52,9 @@ class MatrixFederationAgent(object):
             SRVResolver impl to use for looking up SRV records. None to use a default
             implementation.
 
-        _well_known_cache (TTLCache|None):
-            TTLCache impl for storing cached well-known lookups. None to use a default
-            implementation.
+        _well_known_resolver (WellKnownResolver|None):
+            WellKnownResolver to use to perform well-known lookups. None to use a
+            default implementation.
     """
 
     def __init__(
@@ -61,49 +62,49 @@ class MatrixFederationAgent(object):
         reactor,
         tls_client_options_factory,
         _srv_resolver=None,
-        _well_known_cache=None,
+        _well_known_resolver=None,
     ):
         self._reactor = reactor
         self._clock = Clock(reactor)
-
-        self._tls_client_options_factory = tls_client_options_factory
-        if _srv_resolver is None:
-            _srv_resolver = SrvResolver()
-        self._srv_resolver = _srv_resolver
-
         self._pool = HTTPConnectionPool(reactor)
         self._pool.retryAutomatically = False
         self._pool.maxPersistentPerHost = 5
         self._pool.cachedConnectionTimeout = 2 * 60
 
-        self._well_known_resolver = WellKnownResolver(
+        self._agent = Agent.usingEndpointFactory(
             self._reactor,
-            agent=Agent(
-                self._reactor,
-                pool=self._pool,
-                contextFactory=tls_client_options_factory,
+            MatrixHostnameEndpointFactory(
+                reactor, tls_client_options_factory, _srv_resolver
             ),
-            well_known_cache=_well_known_cache,
+            pool=self._pool,
         )
 
+        if _well_known_resolver is None:
+            _well_known_resolver = WellKnownResolver(
+                self._reactor,
+                agent=Agent(
+                    self._reactor,
+                    pool=self._pool,
+                    contextFactory=tls_client_options_factory,
+                ),
+            )
+
+        self._well_known_resolver = _well_known_resolver
+
     @defer.inlineCallbacks
     def request(self, method, uri, headers=None, bodyProducer=None):
         """
         Args:
             method (bytes): HTTP method: GET/POST/etc
-
             uri (bytes): Absolute URI to be retrieved
-
             headers (twisted.web.http_headers.Headers|None):
                 HTTP headers to send with the request, or None to
                 send no extra headers.
-
             bodyProducer (twisted.web.iweb.IBodyProducer|None):
                 An object which can generate bytes to make up the
                 body of this request (for example, the properly encoded contents of
                 a file for a file upload).  Or None if the request is to have
                 no body.
-
         Returns:
             Deferred[twisted.web.iweb.IResponse]:
                 fires when the header of the response has been received (regardless of the
@@ -111,210 +112,207 @@ class MatrixFederationAgent(object):
                 response from being received (including problems that prevent the request
                 from being sent).
         """
-        parsed_uri = URI.fromBytes(uri, defaultPort=-1)
-        res = yield self._route_matrix_uri(parsed_uri)
+        # We use urlparse as that will set `port` to None if there is no
+        # explicit port.
+        parsed_uri = urllib.parse.urlparse(uri)
 
-        # set up the TLS connection params
+        # If this is a matrix:// URI check if the server has delegated matrix
+        # traffic using well-known delegation.
         #
-        # XXX disabling TLS is really only supported here for the benefit of the
-        # unit tests. We should make the UTs cope with TLS rather than having to make
-        # the code support the unit tests.
-        if self._tls_client_options_factory is None:
-            tls_options = None
-        else:
-            tls_options = self._tls_client_options_factory.get_options(
-                res.tls_server_name.decode("ascii")
+        # We have to do this here and not in the endpoint as we need to rewrite
+        # the host header with the delegated server name.
+        delegated_server = None
+        if (
+            parsed_uri.scheme == b"matrix"
+            and not _is_ip_literal(parsed_uri.hostname)
+            and not parsed_uri.port
+        ):
+            well_known_result = yield self._well_known_resolver.get_well_known(
+                parsed_uri.hostname
+            )
+            delegated_server = well_known_result.delegated_server
+
+        if delegated_server:
+            # Ok, the server has delegated matrix traffic to somewhere else, so
+            # lets rewrite the URL to replace the server with the delegated
+            # server name.
+            uri = urllib.parse.urlunparse(
+                (
+                    parsed_uri.scheme,
+                    delegated_server,
+                    parsed_uri.path,
+                    parsed_uri.params,
+                    parsed_uri.query,
+                    parsed_uri.fragment,
+                )
             )
+            parsed_uri = urllib.parse.urlparse(uri)
 
-        # make sure that the Host header is set correctly
+        # We need to make sure the host header is set to the netloc of the
+        # server.
         if headers is None:
             headers = Headers()
         else:
             headers = headers.copy()
 
         if not headers.hasHeader(b"host"):
-            headers.addRawHeader(b"host", res.host_header)
+            headers.addRawHeader(b"host", parsed_uri.netloc)
 
-        class EndpointFactory(object):
-            @staticmethod
-            def endpointForURI(_uri):
-                ep = LoggingHostnameEndpoint(
-                    self._reactor, res.target_host, res.target_port
-                )
-                if tls_options is not None:
-                    ep = wrapClientTLS(tls_options, ep)
-                return ep
-
-        agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
         res = yield make_deferred_yieldable(
-            agent.request(method, uri, headers, bodyProducer)
+            self._agent.request(method, uri, headers, bodyProducer)
         )
+
         return res
 
-    @defer.inlineCallbacks
-    def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
-        """Helper for `request`: determine the routing for a Matrix URI
 
-        Args:
-            parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
-                parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
-                if there is no explicit port given.
+@implementer(IAgentEndpointFactory)
+class MatrixHostnameEndpointFactory(object):
+    """Factory for MatrixHostnameEndpoint for parsing to an Agent.
+    """
 
-            lookup_well_known (bool): True if we should look up the .well-known file if
-                there is no SRV record.
+    def __init__(self, reactor, tls_client_options_factory, srv_resolver):
+        self._reactor = reactor
+        self._tls_client_options_factory = tls_client_options_factory
 
-        Returns:
-            Deferred[_RoutingResult]
-        """
-        # check for an IP literal
-        try:
-            ip_address = IPAddress(parsed_uri.host.decode("ascii"))
-        except Exception:
-            # not an IP address
-            ip_address = None
-
-        if ip_address:
-            port = parsed_uri.port
-            if port == -1:
-                port = 8448
-            return _RoutingResult(
-                host_header=parsed_uri.netloc,
-                tls_server_name=parsed_uri.host,
-                target_host=parsed_uri.host,
-                target_port=port,
-            )
+        if srv_resolver is None:
+            srv_resolver = SrvResolver()
 
-        if parsed_uri.port != -1:
-            # there is an explicit port
-            return _RoutingResult(
-                host_header=parsed_uri.netloc,
-                tls_server_name=parsed_uri.host,
-                target_host=parsed_uri.host,
-                target_port=parsed_uri.port,
-            )
+        self._srv_resolver = srv_resolver
 
-        if lookup_well_known:
-            # try a .well-known lookup
-            well_known_result = yield self._well_known_resolver.get_well_known(
-                parsed_uri.host
-            )
-            well_known_server = well_known_result.delegated_server
-
-            if well_known_server:
-                # if we found a .well-known, start again, but don't do another
-                # .well-known lookup.
-
-                # parse the server name in the .well-known response into host/port.
-                # (This code is lifted from twisted.web.client.URI.fromBytes).
-                if b":" in well_known_server:
-                    well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
-                    try:
-                        well_known_port = int(well_known_port)
-                    except ValueError:
-                        # the part after the colon could not be parsed as an int
-                        # - we assume it is an IPv6 literal with no port (the closing
-                        # ']' stops it being parsed as an int)
-                        well_known_host, well_known_port = well_known_server, -1
-                else:
-                    well_known_host, well_known_port = well_known_server, -1
-
-                new_uri = URI(
-                    scheme=parsed_uri.scheme,
-                    netloc=well_known_server,
-                    host=well_known_host,
-                    port=well_known_port,
-                    path=parsed_uri.path,
-                    params=parsed_uri.params,
-                    query=parsed_uri.query,
-                    fragment=parsed_uri.fragment,
-                )
+    def endpointForURI(self, parsed_uri):
+        return MatrixHostnameEndpoint(
+            self._reactor,
+            self._tls_client_options_factory,
+            self._srv_resolver,
+            parsed_uri,
+        )
 
-                res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
-                return res
-
-        # try a SRV lookup
-        service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
-        server_list = yield self._srv_resolver.resolve_service(service_name)
-
-        if not server_list:
-            target_host = parsed_uri.host
-            port = 8448
-            logger.debug(
-                "No SRV record for %s, using %s:%i",
-                parsed_uri.host.decode("ascii"),
-                target_host.decode("ascii"),
-                port,
-            )
+
+@implementer(IStreamClientEndpoint)
+class MatrixHostnameEndpoint(object):
+    """An endpoint that resolves matrix:// URLs using Matrix server name
+    resolution (i.e. via SRV). Does not check for well-known delegation.
+
+    Args:
+        reactor (IReactor)
+        tls_client_options_factory (ClientTLSOptionsFactory|None):
+            factory to use for fetching client tls options, or none to disable TLS.
+        srv_resolver (SrvResolver): The SRV resolver to use
+        parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
+            to connect to.
+    """
+
+    def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
+        self._reactor = reactor
+
+        self._parsed_uri = parsed_uri
+
+        # set up the TLS connection params
+        #
+        # XXX disabling TLS is really only supported here for the benefit of the
+        # unit tests. We should make the UTs cope with TLS rather than having to make
+        # the code support the unit tests.
+
+        if tls_client_options_factory is None:
+            self._tls_options = None
         else:
-            target_host, port = pick_server_from_list(server_list)
-            logger.debug(
-                "Picked %s:%i from SRV records for %s",
-                target_host.decode("ascii"),
-                port,
-                parsed_uri.host.decode("ascii"),
+            self._tls_options = tls_client_options_factory.get_options(
+                self._parsed_uri.host
             )
 
-        return _RoutingResult(
-            host_header=parsed_uri.netloc,
-            tls_server_name=parsed_uri.host,
-            target_host=target_host,
-            target_port=port,
-        )
+        self._srv_resolver = srv_resolver
 
+    def connect(self, protocol_factory):
+        """Implements IStreamClientEndpoint interface
+        """
 
-@implementer(IStreamClientEndpoint)
-class LoggingHostnameEndpoint(object):
-    """A wrapper for HostnameEndpint which logs when it connects"""
+        return run_in_background(self._do_connect, protocol_factory)
 
-    def __init__(self, reactor, host, port, *args, **kwargs):
-        self.host = host
-        self.port = port
-        self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
+    @defer.inlineCallbacks
+    def _do_connect(self, protocol_factory):
+        first_exception = None
+
+        server_list = yield self._resolve_server()
+
+        for server in server_list:
+            host = server.host
+            port = server.port
+
+            try:
+                logger.info("Connecting to %s:%i", host.decode("ascii"), port)
+                endpoint = HostnameEndpoint(self._reactor, host, port)
+                if self._tls_options:
+                    endpoint = wrapClientTLS(self._tls_options, endpoint)
+                result = yield make_deferred_yieldable(
+                    endpoint.connect(protocol_factory)
+                )
 
-    def connect(self, protocol_factory):
-        logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
-        return self.ep.connect(protocol_factory)
+                return result
+            except Exception as e:
+                logger.info(
+                    "Failed to connect to %s:%i: %s", host.decode("ascii"), port, e
+                )
+                if not first_exception:
+                    first_exception = e
 
+        # We return the first failure because that's probably the most interesting.
+        if first_exception:
+            raise first_exception
 
-@attr.s
-class _RoutingResult(object):
-    """The result returned by `_route_matrix_uri`.
+        # This shouldn't happen as we should always have at least one host/port
+        # to try and if that doesn't work then we'll have an exception.
+        raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
 
-    Contains the parameters needed to direct a federation connection to a particular
-    server.
+    @defer.inlineCallbacks
+    def _resolve_server(self):
+        """Resolves the server name to a list of hosts and ports to attempt to
+        connect to.
 
-    Where a SRV record points to several servers, this object contains a single server
-    chosen from the list.
-    """
+        Returns:
+            Deferred[list[Server]]
+        """
 
-    host_header = attr.ib()
-    """
-    The value we should assign to the Host header (host:port from the matrix
-    URI, or .well-known).
+        if self._parsed_uri.scheme != b"matrix":
+            return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)]
 
-    :type: bytes
-    """
+        # Note: We don't do well-known lookup as that needs to have happened
+        # before now, due to needing to rewrite the Host header of the HTTP
+        # request.
 
-    tls_server_name = attr.ib()
-    """
-    The server name we should set in the SNI (typically host, without port, from the
-    matrix URI or .well-known)
+        # We reparse the URI so that defaultPort is -1 rather than 80
+        parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes())
 
-    :type: bytes
-    """
+        host = parsed_uri.hostname
+        port = parsed_uri.port
 
-    target_host = attr.ib()
-    """
-    The hostname (or IP literal) we should route the TCP connection to (the target of the
-    SRV record, or the hostname from the URL/.well-known)
+        # If there is an explicit port or the host is an IP address we bypass
+        # SRV lookups and just use the given host/port.
+        if port or _is_ip_literal(host):
+            return [Server(host, port or 8448)]
 
-    :type: bytes
-    """
+        server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
+
+        if server_list:
+            return server_list
+
+        # No SRV records, so we fallback to host and 8448
+        return [Server(host, 8448)]
 
-    target_port = attr.ib()
-    """
-    The port we should route the TCP connection to (the target of the SRV record, or
-    the port from the URL/.well-known, or 8448)
 
-    :type: int
+def _is_ip_literal(host):
+    """Test if the given host name is either an IPv4 or IPv6 literal.
+
+    Args:
+        host (bytes)
+
+    Returns:
+        bool
     """
+
+    host = host.decode("ascii")
+
+    try:
+        IPAddress(host)
+        return True
+    except AddrFormatError:
+        return False
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b32188766d..3fe4ffb9e5 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
 SERVER_CACHE = {}
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class Server(object):
     """
     Our record of an individual server which can be tried to reach a destination.
@@ -53,34 +53,47 @@ class Server(object):
     expires = attr.ib(default=0)
 
 
-def pick_server_from_list(server_list):
-    """Randomly choose a server from the server list
+def _sort_server_list(server_list):
+    """Given a list of SRV records sort them into priority order and shuffle
+    each priority with the given weight.
+    """
+    priority_map = {}
 
-    Args:
-        server_list (list[Server]): list of candidate servers
+    for server in server_list:
+        priority_map.setdefault(server.priority, []).append(server)
 
-    Returns:
-        Tuple[bytes, int]: (host, port) pair for the chosen server
-    """
-    if not server_list:
-        raise RuntimeError("pick_server_from_list called with empty list")
+    results = []
+    for priority in sorted(priority_map):
+        servers = priority_map[priority]
+
+        # This algorithms roughly follows the algorithm described in RFC2782,
+        # changed to remove an off-by-one error.
+        #
+        # N.B. Weights can be zero, which means that they should be picked
+        # rarely.
+
+        total_weight = sum(s.weight for s in servers)
+
+        # Total weight can become zero if there are only zero weight servers
+        # left, which we handle by just shuffling and appending to the results.
+        while servers and total_weight:
+            target_weight = random.randint(1, total_weight)
 
-    # TODO: currently we only use the lowest-priority servers. We should maintain a
-    # cache of servers known to be "down" and filter them out
+            for s in servers:
+                target_weight -= s.weight
 
-    min_priority = min(s.priority for s in server_list)
-    eligible_servers = list(s for s in server_list if s.priority == min_priority)
-    total_weight = sum(s.weight for s in eligible_servers)
-    target_weight = random.randint(0, total_weight)
+                if target_weight <= 0:
+                    break
 
-    for s in eligible_servers:
-        target_weight -= s.weight
+            results.append(s)
+            servers.remove(s)
+            total_weight -= s.weight
 
-        if target_weight <= 0:
-            return s.host, s.port
+        if servers:
+            random.shuffle(servers)
+            results.extend(servers)
 
-    # this should be impossible.
-    raise RuntimeError("pick_server_from_list got to end of eligible server list.")
+    return results
 
 
 class SrvResolver(object):
@@ -120,7 +133,7 @@ class SrvResolver(object):
         if cache_entry:
             if all(s.expires > now for s in cache_entry):
                 servers = list(cache_entry)
-                return servers
+                return _sort_server_list(servers)
 
         try:
             answers, _, _ = yield make_deferred_yieldable(
@@ -169,4 +182,4 @@ class SrvResolver(object):
             )
 
         self._cache[service_name] = list(servers)
-        return servers
+        return _sort_server_list(servers)
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index d2866ff67d..7ddfad286d 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -32,22 +32,40 @@ from synapse.util.metrics import Measure
 # period to cache .well-known results for by default
 WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
 
-# jitter to add to the .well-known default cache ttl
-WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
+# jitter factor to add to the .well-known default cache ttls
+WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 0.1
 
 # period to cache failure to fetch .well-known for
 WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
 
+# period to cache failure to fetch .well-known if there has recently been a
+# valid well-known for that domain.
+WELL_KNOWN_DOWN_CACHE_PERIOD = 2 * 60
+
+# period to remember there was a valid well-known after valid record expires
+WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID = 2 * 3600
+
 # cap for .well-known cache period
 WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
 
 # lower bound for .well-known cache period
 WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60
 
+# Attempt to refetch a cached well-known N% of the TTL before it expires.
+# e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then
+# we'll start trying to refetch 1 minute before it expires.
+WELL_KNOWN_GRACE_PERIOD_FACTOR = 0.2
+
+# Number of times we retry fetching a well-known for a domain we know recently
+# had a valid entry.
+WELL_KNOWN_RETRY_ATTEMPTS = 3
+
+
 logger = logging.getLogger(__name__)
 
 
 _well_known_cache = TTLCache("well-known")
+_had_valid_well_known_cache = TTLCache("had-valid-well-known")
 
 
 @attr.s(slots=True, frozen=True)
@@ -59,14 +77,20 @@ class WellKnownResolver(object):
     """Handles well-known lookups for matrix servers.
     """
 
-    def __init__(self, reactor, agent, well_known_cache=None):
+    def __init__(
+        self, reactor, agent, well_known_cache=None, had_well_known_cache=None
+    ):
         self._reactor = reactor
         self._clock = Clock(reactor)
 
         if well_known_cache is None:
             well_known_cache = _well_known_cache
 
+        if had_well_known_cache is None:
+            had_well_known_cache = _had_valid_well_known_cache
+
         self._well_known_cache = well_known_cache
+        self._had_valid_well_known_cache = had_well_known_cache
         self._well_known_agent = RedirectAgent(agent)
 
     @defer.inlineCallbacks
@@ -80,59 +104,86 @@ class WellKnownResolver(object):
             Deferred[WellKnownLookupResult]: The result of the lookup
         """
         try:
-            result = self._well_known_cache[server_name]
+            prev_result, expiry, ttl = self._well_known_cache.get_with_expiry(
+                server_name
+            )
+
+            now = self._clock.time()
+            if now < expiry - WELL_KNOWN_GRACE_PERIOD_FACTOR * ttl:
+                return WellKnownLookupResult(delegated_server=prev_result)
         except KeyError:
-            # TODO: should we linearise so that we don't end up doing two .well-known
-            # requests for the same server in parallel?
+            prev_result = None
+
+        # TODO: should we linearise so that we don't end up doing two .well-known
+        # requests for the same server in parallel?
+        try:
             with Measure(self._clock, "get_well_known"):
-                result, cache_period = yield self._do_get_well_known(server_name)
+                result, cache_period = yield self._fetch_well_known(server_name)
+
+        except _FetchWellKnownFailure as e:
+            if prev_result and e.temporary:
+                # This is a temporary failure and we have a still valid cached
+                # result, so lets return that. Hopefully the next time we ask
+                # the remote will be back up again.
+                return WellKnownLookupResult(delegated_server=prev_result)
+
+            result = None
+
+            if self._had_valid_well_known_cache.get(server_name, False):
+                # We have recently seen a valid well-known record for this
+                # server, so we cache the lack of well-known for a shorter time.
+                cache_period = WELL_KNOWN_DOWN_CACHE_PERIOD
+            else:
+                cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
+
+            # add some randomness to the TTL to avoid a stampeding herd
+            cache_period *= random.uniform(
+                1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+                1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+            )
 
-            if cache_period > 0:
-                self._well_known_cache.set(server_name, result, cache_period)
+        if cache_period > 0:
+            self._well_known_cache.set(server_name, result, cache_period)
 
         return WellKnownLookupResult(delegated_server=result)
 
     @defer.inlineCallbacks
-    def _do_get_well_known(self, server_name):
+    def _fetch_well_known(self, server_name):
         """Actually fetch and parse a .well-known, without checking the cache
 
         Args:
             server_name (bytes): name of the server, from the requested url
 
+        Raises:
+            _FetchWellKnownFailure if we fail to lookup a result
+
         Returns:
-            Deferred[Tuple[bytes|None|object],int]:
-                result, cache period, where result is one of:
-                 - the new server name from the .well-known (as a `bytes`)
-                 - None if there was no .well-known file.
-                 - INVALID_WELL_KNOWN if the .well-known was invalid
+            Deferred[Tuple[bytes,int]]: The lookup result and cache period.
         """
-        uri = b"https://%s/.well-known/matrix/server" % (server_name,)
-        uri_str = uri.decode("ascii")
-        logger.info("Fetching %s", uri_str)
+
+        had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
+
+        # We do this in two steps to differentiate between possibly transient
+        # errors (e.g. can't connect to host, 503 response) and more permenant
+        # errors (such as getting a 404 response).
+        response, body = yield self._make_well_known_request(
+            server_name, retry=had_valid_well_known
+        )
+
         try:
-            response = yield make_deferred_yieldable(
-                self._well_known_agent.request(b"GET", uri)
-            )
-            body = yield make_deferred_yieldable(readBody(response))
             if response.code != 200:
                 raise Exception("Non-200 response %s" % (response.code,))
 
             parsed_body = json.loads(body.decode("utf-8"))
             logger.info("Response from .well-known: %s", parsed_body)
-            if not isinstance(parsed_body, dict):
-                raise Exception("not a dict")
-            if "m.server" not in parsed_body:
-                raise Exception("Missing key 'm.server'")
-        except Exception as e:
-            logger.info("Error fetching %s: %s", uri_str, e)
-
-            # add some randomness to the TTL to avoid a stampeding herd every hour
-            # after startup
-            cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
-            cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
-            return (None, cache_period)
 
-        result = parsed_body["m.server"].encode("ascii")
+            result = parsed_body["m.server"].encode("ascii")
+        except defer.CancelledError:
+            # Bail if we've been cancelled
+            raise
+        except Exception as e:
+            logger.info("Error parsing well-known for %s: %s", server_name, e)
+            raise _FetchWellKnownFailure(temporary=False)
 
         cache_period = _cache_period_from_headers(
             response.headers, time_now=self._reactor.seconds
@@ -141,12 +192,68 @@ class WellKnownResolver(object):
             cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
             # add some randomness to the TTL to avoid a stampeding herd every 24 hours
             # after startup
-            cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
+            cache_period *= random.uniform(
+                1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+                1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+            )
         else:
             cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
             cache_period = max(cache_period, WELL_KNOWN_MIN_CACHE_PERIOD)
 
-        return (result, cache_period)
+        # We got a success, mark as such in the cache
+        self._had_valid_well_known_cache.set(
+            server_name,
+            bool(result),
+            cache_period + WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID,
+        )
+
+        return result, cache_period
+
+    @defer.inlineCallbacks
+    def _make_well_known_request(self, server_name, retry):
+        """Make the well known request.
+
+        This will retry the request if requested and it fails (with unable
+        to connect or receives a 5xx error).
+
+        Args:
+            server_name (bytes)
+            retry (bool): Whether to retry the request if it fails.
+
+        Returns:
+            Deferred[tuple[IResponse, bytes]] Returns the response object and
+            body. Response may be a non-200 response.
+        """
+        uri = b"https://%s/.well-known/matrix/server" % (server_name,)
+        uri_str = uri.decode("ascii")
+
+        i = 0
+        while True:
+            i += 1
+
+            logger.info("Fetching %s", uri_str)
+            try:
+                response = yield make_deferred_yieldable(
+                    self._well_known_agent.request(b"GET", uri)
+                )
+                body = yield make_deferred_yieldable(readBody(response))
+
+                if 500 <= response.code < 600:
+                    raise Exception("Non-200 response %s" % (response.code,))
+
+                return response, body
+            except defer.CancelledError:
+                # Bail if we've been cancelled
+                raise
+            except Exception as e:
+                if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS:
+                    logger.info("Error fetching %s: %s", uri_str, e)
+                    raise _FetchWellKnownFailure(temporary=True)
+
+                logger.info("Error fetching %s: %s. Retrying", uri_str, e)
+
+            # Sleep briefly in the hopes that they come back up
+            yield self._clock.sleep(0.5)
 
 
 def _cache_period_from_headers(headers, time_now=time.time):
@@ -185,3 +292,10 @@ def _parse_cache_control(headers):
             v = splits[1] if len(splits) > 1 else None
             cache_controls[k] = v
     return cache_controls
+
+
+@attr.s()
+class _FetchWellKnownFailure(Exception):
+    # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
+    # a temporary failure.
+    temporary = attr.ib()