diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 71a15f434d..647d26dc56 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,21 +14,21 @@
# limitations under the License.
import logging
+import urllib
-import attr
-from netaddr import IPAddress
+from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import URI, Agent, HTTPConnectionPool
+from twisted.web.client import Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent
+from twisted.web.iweb import IAgent, IAgentEndpointFactory
-from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -36,8 +36,9 @@ logger = logging.getLogger(__name__)
@implementer(IAgent)
class MatrixFederationAgent(object):
- """An Agent-like thing which provides a `request` method which will look up a matrix
- server and send an HTTP request to it.
+ """An Agent-like thing which provides a `request` method which correctly
+ handles resolving matrix server names when using matrix://. Handles standard
+ https URIs as normal.
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
@@ -51,9 +52,9 @@ class MatrixFederationAgent(object):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
- _well_known_cache (TTLCache|None):
- TTLCache impl for storing cached well-known lookups. None to use a default
- implementation.
+ _well_known_resolver (WellKnownResolver|None):
+ WellKnownResolver to use to perform well-known lookups. None to use a
+ default implementation.
"""
def __init__(
@@ -61,49 +62,49 @@ class MatrixFederationAgent(object):
reactor,
tls_client_options_factory,
_srv_resolver=None,
- _well_known_cache=None,
+ _well_known_resolver=None,
):
self._reactor = reactor
self._clock = Clock(reactor)
-
- self._tls_client_options_factory = tls_client_options_factory
- if _srv_resolver is None:
- _srv_resolver = SrvResolver()
- self._srv_resolver = _srv_resolver
-
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60
- self._well_known_resolver = WellKnownResolver(
+ self._agent = Agent.usingEndpointFactory(
self._reactor,
- agent=Agent(
- self._reactor,
- pool=self._pool,
- contextFactory=tls_client_options_factory,
+ MatrixHostnameEndpointFactory(
+ reactor, tls_client_options_factory, _srv_resolver
),
- well_known_cache=_well_known_cache,
+ pool=self._pool,
)
+ if _well_known_resolver is None:
+ _well_known_resolver = WellKnownResolver(
+ self._reactor,
+ agent=Agent(
+ self._reactor,
+ pool=self._pool,
+ contextFactory=tls_client_options_factory,
+ ),
+ )
+
+ self._well_known_resolver = _well_known_resolver
+
@defer.inlineCallbacks
def request(self, method, uri, headers=None, bodyProducer=None):
"""
Args:
method (bytes): HTTP method: GET/POST/etc
-
uri (bytes): Absolute URI to be retrieved
-
headers (twisted.web.http_headers.Headers|None):
HTTP headers to send with the request, or None to
send no extra headers.
-
bodyProducer (twisted.web.iweb.IBodyProducer|None):
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
no body.
-
Returns:
Deferred[twisted.web.iweb.IResponse]:
fires when the header of the response has been received (regardless of the
@@ -111,210 +112,207 @@ class MatrixFederationAgent(object):
response from being received (including problems that prevent the request
from being sent).
"""
- parsed_uri = URI.fromBytes(uri, defaultPort=-1)
- res = yield self._route_matrix_uri(parsed_uri)
+ # We use urlparse as that will set `port` to None if there is no
+ # explicit port.
+ parsed_uri = urllib.parse.urlparse(uri)
- # set up the TLS connection params
+ # If this is a matrix:// URI check if the server has delegated matrix
+ # traffic using well-known delegation.
#
- # XXX disabling TLS is really only supported here for the benefit of the
- # unit tests. We should make the UTs cope with TLS rather than having to make
- # the code support the unit tests.
- if self._tls_client_options_factory is None:
- tls_options = None
- else:
- tls_options = self._tls_client_options_factory.get_options(
- res.tls_server_name.decode("ascii")
+ # We have to do this here and not in the endpoint as we need to rewrite
+ # the host header with the delegated server name.
+ delegated_server = None
+ if (
+ parsed_uri.scheme == b"matrix"
+ and not _is_ip_literal(parsed_uri.hostname)
+ and not parsed_uri.port
+ ):
+ well_known_result = yield self._well_known_resolver.get_well_known(
+ parsed_uri.hostname
+ )
+ delegated_server = well_known_result.delegated_server
+
+ if delegated_server:
+ # Ok, the server has delegated matrix traffic to somewhere else, so
+            # let's rewrite the URL to replace the server with the delegated
+ # server name.
+ uri = urllib.parse.urlunparse(
+ (
+ parsed_uri.scheme,
+ delegated_server,
+ parsed_uri.path,
+ parsed_uri.params,
+ parsed_uri.query,
+ parsed_uri.fragment,
+ )
)
+ parsed_uri = urllib.parse.urlparse(uri)
- # make sure that the Host header is set correctly
+ # We need to make sure the host header is set to the netloc of the
+ # server.
if headers is None:
headers = Headers()
else:
headers = headers.copy()
if not headers.hasHeader(b"host"):
- headers.addRawHeader(b"host", res.host_header)
+ headers.addRawHeader(b"host", parsed_uri.netloc)
- class EndpointFactory(object):
- @staticmethod
- def endpointForURI(_uri):
- ep = LoggingHostnameEndpoint(
- self._reactor, res.target_host, res.target_port
- )
- if tls_options is not None:
- ep = wrapClientTLS(tls_options, ep)
- return ep
-
- agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
res = yield make_deferred_yieldable(
- agent.request(method, uri, headers, bodyProducer)
+ self._agent.request(method, uri, headers, bodyProducer)
)
+
return res
- @defer.inlineCallbacks
- def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
- """Helper for `request`: determine the routing for a Matrix URI
- Args:
- parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
- parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
- if there is no explicit port given.
+@implementer(IAgentEndpointFactory)
+class MatrixHostnameEndpointFactory(object):
+    """Factory for MatrixHostnameEndpoint for passing to an Agent.
+ """
- lookup_well_known (bool): True if we should look up the .well-known file if
- there is no SRV record.
+ def __init__(self, reactor, tls_client_options_factory, srv_resolver):
+ self._reactor = reactor
+ self._tls_client_options_factory = tls_client_options_factory
- Returns:
- Deferred[_RoutingResult]
- """
- # check for an IP literal
- try:
- ip_address = IPAddress(parsed_uri.host.decode("ascii"))
- except Exception:
- # not an IP address
- ip_address = None
-
- if ip_address:
- port = parsed_uri.port
- if port == -1:
- port = 8448
- return _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=port,
- )
+ if srv_resolver is None:
+ srv_resolver = SrvResolver()
- if parsed_uri.port != -1:
- # there is an explicit port
- return _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=parsed_uri.port,
- )
+ self._srv_resolver = srv_resolver
- if lookup_well_known:
- # try a .well-known lookup
- well_known_result = yield self._well_known_resolver.get_well_known(
- parsed_uri.host
- )
- well_known_server = well_known_result.delegated_server
-
- if well_known_server:
- # if we found a .well-known, start again, but don't do another
- # .well-known lookup.
-
- # parse the server name in the .well-known response into host/port.
- # (This code is lifted from twisted.web.client.URI.fromBytes).
- if b":" in well_known_server:
- well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
- try:
- well_known_port = int(well_known_port)
- except ValueError:
- # the part after the colon could not be parsed as an int
- # - we assume it is an IPv6 literal with no port (the closing
- # ']' stops it being parsed as an int)
- well_known_host, well_known_port = well_known_server, -1
- else:
- well_known_host, well_known_port = well_known_server, -1
-
- new_uri = URI(
- scheme=parsed_uri.scheme,
- netloc=well_known_server,
- host=well_known_host,
- port=well_known_port,
- path=parsed_uri.path,
- params=parsed_uri.params,
- query=parsed_uri.query,
- fragment=parsed_uri.fragment,
- )
+ def endpointForURI(self, parsed_uri):
+ return MatrixHostnameEndpoint(
+ self._reactor,
+ self._tls_client_options_factory,
+ self._srv_resolver,
+ parsed_uri,
+ )
- res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
- return res
-
- # try a SRV lookup
- service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
- server_list = yield self._srv_resolver.resolve_service(service_name)
-
- if not server_list:
- target_host = parsed_uri.host
- port = 8448
- logger.debug(
- "No SRV record for %s, using %s:%i",
- parsed_uri.host.decode("ascii"),
- target_host.decode("ascii"),
- port,
- )
+
+@implementer(IStreamClientEndpoint)
+class MatrixHostnameEndpoint(object):
+ """An endpoint that resolves matrix:// URLs using Matrix server name
+ resolution (i.e. via SRV). Does not check for well-known delegation.
+
+ Args:
+ reactor (IReactor)
+ tls_client_options_factory (ClientTLSOptionsFactory|None):
+            factory to use for fetching client TLS options, or None to disable TLS.
+ srv_resolver (SrvResolver): The SRV resolver to use
+ parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
+ to connect to.
+ """
+
+ def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
+ self._reactor = reactor
+
+ self._parsed_uri = parsed_uri
+
+ # set up the TLS connection params
+ #
+ # XXX disabling TLS is really only supported here for the benefit of the
+ # unit tests. We should make the UTs cope with TLS rather than having to make
+ # the code support the unit tests.
+
+ if tls_client_options_factory is None:
+ self._tls_options = None
else:
- target_host, port = pick_server_from_list(server_list)
- logger.debug(
- "Picked %s:%i from SRV records for %s",
- target_host.decode("ascii"),
- port,
- parsed_uri.host.decode("ascii"),
+ self._tls_options = tls_client_options_factory.get_options(
+ self._parsed_uri.host
)
- return _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=target_host,
- target_port=port,
- )
+ self._srv_resolver = srv_resolver
+ def connect(self, protocol_factory):
+ """Implements IStreamClientEndpoint interface
+ """
-@implementer(IStreamClientEndpoint)
-class LoggingHostnameEndpoint(object):
- """A wrapper for HostnameEndpint which logs when it connects"""
+ return run_in_background(self._do_connect, protocol_factory)
- def __init__(self, reactor, host, port, *args, **kwargs):
- self.host = host
- self.port = port
- self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
+ @defer.inlineCallbacks
+ def _do_connect(self, protocol_factory):
+ first_exception = None
+
+ server_list = yield self._resolve_server()
+
+ for server in server_list:
+ host = server.host
+ port = server.port
+
+ try:
+ logger.info("Connecting to %s:%i", host.decode("ascii"), port)
+ endpoint = HostnameEndpoint(self._reactor, host, port)
+ if self._tls_options:
+ endpoint = wrapClientTLS(self._tls_options, endpoint)
+ result = yield make_deferred_yieldable(
+ endpoint.connect(protocol_factory)
+ )
- def connect(self, protocol_factory):
- logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
- return self.ep.connect(protocol_factory)
+ return result
+ except Exception as e:
+ logger.info(
+ "Failed to connect to %s:%i: %s", host.decode("ascii"), port, e
+ )
+ if not first_exception:
+ first_exception = e
+ # We return the first failure because that's probably the most interesting.
+ if first_exception:
+ raise first_exception
-@attr.s
-class _RoutingResult(object):
- """The result returned by `_route_matrix_uri`.
+ # This shouldn't happen as we should always have at least one host/port
+ # to try and if that doesn't work then we'll have an exception.
+ raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
- Contains the parameters needed to direct a federation connection to a particular
- server.
+ @defer.inlineCallbacks
+ def _resolve_server(self):
+ """Resolves the server name to a list of hosts and ports to attempt to
+ connect to.
- Where a SRV record points to several servers, this object contains a single server
- chosen from the list.
- """
+ Returns:
+ Deferred[list[Server]]
+ """
- host_header = attr.ib()
- """
- The value we should assign to the Host header (host:port from the matrix
- URI, or .well-known).
+ if self._parsed_uri.scheme != b"matrix":
+ return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)]
- :type: bytes
- """
+ # Note: We don't do well-known lookup as that needs to have happened
+ # before now, due to needing to rewrite the Host header of the HTTP
+ # request.
- tls_server_name = attr.ib()
- """
- The server name we should set in the SNI (typically host, without port, from the
- matrix URI or .well-known)
+        # We reparse the URI so that the port is None rather than defaulting to 80
+ parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes())
- :type: bytes
- """
+ host = parsed_uri.hostname
+ port = parsed_uri.port
- target_host = attr.ib()
- """
- The hostname (or IP literal) we should route the TCP connection to (the target of the
- SRV record, or the hostname from the URL/.well-known)
+ # If there is an explicit port or the host is an IP address we bypass
+ # SRV lookups and just use the given host/port.
+ if port or _is_ip_literal(host):
+ return [Server(host, port or 8448)]
- :type: bytes
- """
+ server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
+
+ if server_list:
+ return server_list
+
+        # No SRV records, so we fall back to host and 8448
+ return [Server(host, 8448)]
- target_port = attr.ib()
- """
- The port we should route the TCP connection to (the target of the SRV record, or
- the port from the URL/.well-known, or 8448)
- :type: int
+def _is_ip_literal(host):
+ """Test if the given host name is either an IPv4 or IPv6 literal.
+
+ Args:
+ host (bytes)
+
+ Returns:
+ bool
"""
+
+ host = host.decode("ascii")
+
+ try:
+ IPAddress(host)
+ return True
+ except AddrFormatError:
+ return False
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b32188766d..3fe4ffb9e5 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
SERVER_CACHE = {}
-@attr.s
+@attr.s(slots=True, frozen=True)
class Server(object):
"""
Our record of an individual server which can be tried to reach a destination.
@@ -53,34 +53,47 @@ class Server(object):
expires = attr.ib(default=0)
-def pick_server_from_list(server_list):
- """Randomly choose a server from the server list
+def _sort_server_list(server_list):
+ """Given a list of SRV records sort them into priority order and shuffle
+ each priority with the given weight.
+ """
+ priority_map = {}
- Args:
- server_list (list[Server]): list of candidate servers
+ for server in server_list:
+ priority_map.setdefault(server.priority, []).append(server)
- Returns:
- Tuple[bytes, int]: (host, port) pair for the chosen server
- """
- if not server_list:
- raise RuntimeError("pick_server_from_list called with empty list")
+ results = []
+ for priority in sorted(priority_map):
+ servers = priority_map[priority]
+
+        # This algorithm roughly follows the algorithm described in RFC2782,
+ # changed to remove an off-by-one error.
+ #
+ # N.B. Weights can be zero, which means that they should be picked
+ # rarely.
+
+ total_weight = sum(s.weight for s in servers)
+
+ # Total weight can become zero if there are only zero weight servers
+ # left, which we handle by just shuffling and appending to the results.
+ while servers and total_weight:
+ target_weight = random.randint(1, total_weight)
- # TODO: currently we only use the lowest-priority servers. We should maintain a
- # cache of servers known to be "down" and filter them out
+ for s in servers:
+ target_weight -= s.weight
- min_priority = min(s.priority for s in server_list)
- eligible_servers = list(s for s in server_list if s.priority == min_priority)
- total_weight = sum(s.weight for s in eligible_servers)
- target_weight = random.randint(0, total_weight)
+ if target_weight <= 0:
+ break
- for s in eligible_servers:
- target_weight -= s.weight
+ results.append(s)
+ servers.remove(s)
+ total_weight -= s.weight
- if target_weight <= 0:
- return s.host, s.port
+ if servers:
+ random.shuffle(servers)
+ results.extend(servers)
- # this should be impossible.
- raise RuntimeError("pick_server_from_list got to end of eligible server list.")
+ return results
class SrvResolver(object):
@@ -120,7 +133,7 @@ class SrvResolver(object):
if cache_entry:
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
- return servers
+ return _sort_server_list(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
@@ -169,4 +182,4 @@ class SrvResolver(object):
)
self._cache[service_name] = list(servers)
- return servers
+ return _sort_server_list(servers)
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index d2866ff67d..7ddfad286d 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -32,22 +32,40 @@ from synapse.util.metrics import Measure
# period to cache .well-known results for by default
WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
-# jitter to add to the .well-known default cache ttl
-WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
+# jitter factor to add to the .well-known default cache ttls
+WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 0.1
# period to cache failure to fetch .well-known for
WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
+# period to cache failure to fetch .well-known if there has recently been a
+# valid well-known for that domain.
+WELL_KNOWN_DOWN_CACHE_PERIOD = 2 * 60
+
+# period to remember there was a valid well-known after valid record expires
+WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID = 2 * 3600
+
# cap for .well-known cache period
WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
# lower bound for .well-known cache period
WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60
+# Attempt to refetch a cached well-known N% of the TTL before it expires.
+# e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then
+# we'll start trying to refetch 1 minute before it expires.
+WELL_KNOWN_GRACE_PERIOD_FACTOR = 0.2
+
+# Number of times we retry fetching a well-known for a domain we know recently
+# had a valid entry.
+WELL_KNOWN_RETRY_ATTEMPTS = 3
+
+
logger = logging.getLogger(__name__)
_well_known_cache = TTLCache("well-known")
+_had_valid_well_known_cache = TTLCache("had-valid-well-known")
@attr.s(slots=True, frozen=True)
@@ -59,14 +77,20 @@ class WellKnownResolver(object):
"""Handles well-known lookups for matrix servers.
"""
- def __init__(self, reactor, agent, well_known_cache=None):
+ def __init__(
+ self, reactor, agent, well_known_cache=None, had_well_known_cache=None
+ ):
self._reactor = reactor
self._clock = Clock(reactor)
if well_known_cache is None:
well_known_cache = _well_known_cache
+ if had_well_known_cache is None:
+ had_well_known_cache = _had_valid_well_known_cache
+
self._well_known_cache = well_known_cache
+ self._had_valid_well_known_cache = had_well_known_cache
self._well_known_agent = RedirectAgent(agent)
@defer.inlineCallbacks
@@ -80,59 +104,86 @@ class WellKnownResolver(object):
Deferred[WellKnownLookupResult]: The result of the lookup
"""
try:
- result = self._well_known_cache[server_name]
+ prev_result, expiry, ttl = self._well_known_cache.get_with_expiry(
+ server_name
+ )
+
+ now = self._clock.time()
+ if now < expiry - WELL_KNOWN_GRACE_PERIOD_FACTOR * ttl:
+ return WellKnownLookupResult(delegated_server=prev_result)
except KeyError:
- # TODO: should we linearise so that we don't end up doing two .well-known
- # requests for the same server in parallel?
+ prev_result = None
+
+ # TODO: should we linearise so that we don't end up doing two .well-known
+ # requests for the same server in parallel?
+ try:
with Measure(self._clock, "get_well_known"):
- result, cache_period = yield self._do_get_well_known(server_name)
+ result, cache_period = yield self._fetch_well_known(server_name)
+
+ except _FetchWellKnownFailure as e:
+ if prev_result and e.temporary:
+ # This is a temporary failure and we have a still valid cached
+                # result, so let's return that. Hopefully the next time we ask
+ # the remote will be back up again.
+ return WellKnownLookupResult(delegated_server=prev_result)
+
+ result = None
+
+ if self._had_valid_well_known_cache.get(server_name, False):
+ # We have recently seen a valid well-known record for this
+ # server, so we cache the lack of well-known for a shorter time.
+ cache_period = WELL_KNOWN_DOWN_CACHE_PERIOD
+ else:
+ cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
+
+ # add some randomness to the TTL to avoid a stampeding herd
+ cache_period *= random.uniform(
+ 1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+ 1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+ )
- if cache_period > 0:
- self._well_known_cache.set(server_name, result, cache_period)
+ if cache_period > 0:
+ self._well_known_cache.set(server_name, result, cache_period)
return WellKnownLookupResult(delegated_server=result)
@defer.inlineCallbacks
- def _do_get_well_known(self, server_name):
+ def _fetch_well_known(self, server_name):
"""Actually fetch and parse a .well-known, without checking the cache
Args:
server_name (bytes): name of the server, from the requested url
+ Raises:
+ _FetchWellKnownFailure if we fail to lookup a result
+
Returns:
- Deferred[Tuple[bytes|None|object],int]:
- result, cache period, where result is one of:
- - the new server name from the .well-known (as a `bytes`)
- - None if there was no .well-known file.
- - INVALID_WELL_KNOWN if the .well-known was invalid
+ Deferred[Tuple[bytes,int]]: The lookup result and cache period.
"""
- uri = b"https://%s/.well-known/matrix/server" % (server_name,)
- uri_str = uri.decode("ascii")
- logger.info("Fetching %s", uri_str)
+
+ had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
+
+ # We do this in two steps to differentiate between possibly transient
+        # errors (e.g. can't connect to host, 503 response) and more permanent
+ # errors (such as getting a 404 response).
+ response, body = yield self._make_well_known_request(
+ server_name, retry=had_valid_well_known
+ )
+
try:
- response = yield make_deferred_yieldable(
- self._well_known_agent.request(b"GET", uri)
- )
- body = yield make_deferred_yieldable(readBody(response))
if response.code != 200:
raise Exception("Non-200 response %s" % (response.code,))
parsed_body = json.loads(body.decode("utf-8"))
logger.info("Response from .well-known: %s", parsed_body)
- if not isinstance(parsed_body, dict):
- raise Exception("not a dict")
- if "m.server" not in parsed_body:
- raise Exception("Missing key 'm.server'")
- except Exception as e:
- logger.info("Error fetching %s: %s", uri_str, e)
-
- # add some randomness to the TTL to avoid a stampeding herd every hour
- # after startup
- cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
- cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
- return (None, cache_period)
- result = parsed_body["m.server"].encode("ascii")
+ result = parsed_body["m.server"].encode("ascii")
+ except defer.CancelledError:
+ # Bail if we've been cancelled
+ raise
+ except Exception as e:
+ logger.info("Error parsing well-known for %s: %s", server_name, e)
+ raise _FetchWellKnownFailure(temporary=False)
cache_period = _cache_period_from_headers(
response.headers, time_now=self._reactor.seconds
@@ -141,12 +192,68 @@ class WellKnownResolver(object):
cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
# add some randomness to the TTL to avoid a stampeding herd every 24 hours
# after startup
- cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
+ cache_period *= random.uniform(
+ 1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+ 1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+ )
else:
cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
cache_period = max(cache_period, WELL_KNOWN_MIN_CACHE_PERIOD)
- return (result, cache_period)
+ # We got a success, mark as such in the cache
+ self._had_valid_well_known_cache.set(
+ server_name,
+ bool(result),
+ cache_period + WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID,
+ )
+
+ return result, cache_period
+
+ @defer.inlineCallbacks
+ def _make_well_known_request(self, server_name, retry):
+ """Make the well known request.
+
+        This will retry the request, if requested, when it fails (i.e. we are
+        unable to connect or we receive a 5xx response).
+
+ Args:
+ server_name (bytes)
+ retry (bool): Whether to retry the request if it fails.
+
+ Returns:
+ Deferred[tuple[IResponse, bytes]] Returns the response object and
+ body. Response may be a non-200 response.
+ """
+ uri = b"https://%s/.well-known/matrix/server" % (server_name,)
+ uri_str = uri.decode("ascii")
+
+ i = 0
+ while True:
+ i += 1
+
+ logger.info("Fetching %s", uri_str)
+ try:
+ response = yield make_deferred_yieldable(
+ self._well_known_agent.request(b"GET", uri)
+ )
+ body = yield make_deferred_yieldable(readBody(response))
+
+ if 500 <= response.code < 600:
+ raise Exception("Non-200 response %s" % (response.code,))
+
+ return response, body
+ except defer.CancelledError:
+ # Bail if we've been cancelled
+ raise
+ except Exception as e:
+ if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS:
+ logger.info("Error fetching %s: %s", uri_str, e)
+ raise _FetchWellKnownFailure(temporary=True)
+
+ logger.info("Error fetching %s: %s. Retrying", uri_str, e)
+
+ # Sleep briefly in the hopes that they come back up
+ yield self._clock.sleep(0.5)
def _cache_period_from_headers(headers, time_now=time.time):
@@ -185,3 +292,10 @@ def _parse_cache_control(headers):
v = splits[1] if len(splits) > 1 else None
cache_controls[k] = v
return cache_controls
+
+
+@attr.s()
+class _FetchWellKnownFailure(Exception):
+ # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
+ # a temporary failure.
+ temporary = attr.ib()
|