diff options
Diffstat (limited to 'synapse/http')
-rw-r--r-- | synapse/http/federation/matrix_federation_agent.py | 502 | ||||
-rw-r--r-- | synapse/http/federation/srv_resolver.py | 61 | ||||
-rw-r--r-- | synapse/http/federation/well_known_resolver.py | 301 | ||||
-rw-r--r-- | synapse/http/matrixfederationclient.py | 23 | ||||
-rw-r--r-- | synapse/http/servlet.py | 2 |
5 files changed, 534 insertions, 355 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index a0d5139839..feae7de5be 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -12,49 +12,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json + import logging -import random -import time +import urllib -import attr -from netaddr import IPAddress +from netaddr import AddrFormatError, IPAddress from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody -from twisted.web.http import stringToDatetime +from twisted.web.client import Agent, HTTPConnectionPool from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent +from twisted.web.iweb import IAgent, IAgentEndpointFactory -from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list -from synapse.logging.context import make_deferred_yieldable +from synapse.http.federation.srv_resolver import Server, SrvResolver +from synapse.http.federation.well_known_resolver import WellKnownResolver +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.util import Clock -from synapse.util.caches.ttlcache import TTLCache -from synapse.util.metrics import Measure - -# period to cache .well-known results for by default -WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600 - -# jitter to add to the .well-known default cache ttl -WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60 - -# period to cache failure to fetch .well-known for -WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 - -# cap for .well-known cache period -WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 logger = logging.getLogger(__name__) -well_known_cache = TTLCache("well-known") @implementer(IAgent) class MatrixFederationAgent(object): - """An Agent-like thing which provides a `request` method which will look up a matrix - server and send an HTTP request to it. + """An Agent-like thing which provides a `request` method which correctly + handles resolving matrix server names when using matrix://. Handles standard + https URIs as normal. Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) @@ -68,9 +52,9 @@ class MatrixFederationAgent(object): SRVResolver impl to use for looking up SRV records. None to use a default implementation. - _well_known_cache (TTLCache|None): - TTLCache impl for storing cached well-known lookups. None to use a default - implementation. + _well_known_resolver (WellKnownResolver|None): + WellKnownResolver to use to perform well-known lookups. None to use a + default implementation. """ def __init__( @@ -78,54 +62,49 @@ class MatrixFederationAgent(object): reactor, tls_client_options_factory, _srv_resolver=None, - _well_known_cache=well_known_cache, + _well_known_resolver=None, ): self._reactor = reactor self._clock = Clock(reactor) - - self._tls_client_options_factory = tls_client_options_factory - if _srv_resolver is None: - _srv_resolver = SrvResolver() - self._srv_resolver = _srv_resolver - self._pool = HTTPConnectionPool(reactor) self._pool.retryAutomatically = False self._pool.maxPersistentPerHost = 5 self._pool.cachedConnectionTimeout = 2 * 60 - _well_known_agent = RedirectAgent( - Agent( + self._agent = Agent.usingEndpointFactory( + self._reactor, + MatrixHostnameEndpointFactory( + reactor, tls_client_options_factory, _srv_resolver + ), + pool=self._pool, + ) + + if _well_known_resolver is None: + _well_known_resolver = WellKnownResolver( self._reactor, - pool=self._pool, - contextFactory=tls_client_options_factory, + agent=Agent( + self._reactor, + pool=self._pool, + contextFactory=tls_client_options_factory, + ), ) - ) - self._well_known_agent = _well_known_agent - # our cache of .well-known lookup results, mapping from server name - # to delegated name. The values can be: - # `bytes`: a valid server-name - # `None`: there is no (valid) .well-known here - self._well_known_cache = _well_known_cache + self._well_known_resolver = _well_known_resolver @defer.inlineCallbacks def request(self, method, uri, headers=None, bodyProducer=None): """ Args: method (bytes): HTTP method: GET/POST/etc - uri (bytes): Absolute URI to be retrieved - headers (twisted.web.http_headers.Headers|None): HTTP headers to send with the request, or None to send no extra headers. - bodyProducer (twisted.web.iweb.IBodyProducer|None): An object which can generate bytes to make up the body of this request (for example, the properly encoded contents of a file for a file upload). Or None if the request is to have no body. - Returns: Deferred[twisted.web.iweb.IResponse]: fires when the header of the response has been received (regardless of the @@ -133,324 +112,207 @@ class MatrixFederationAgent(object): response from being received (including problems that prevent the request from being sent). """ - parsed_uri = URI.fromBytes(uri, defaultPort=-1) - res = yield self._route_matrix_uri(parsed_uri) + # We use urlparse as that will set `port` to None if there is no + # explicit port. + parsed_uri = urllib.parse.urlparse(uri) - # set up the TLS connection params + # If this is a matrix:// URI check if the server has delegated matrix + # traffic using well-known delegation. # - # XXX disabling TLS is really only supported here for the benefit of the - # unit tests. We should make the UTs cope with TLS rather than having to make - # the code support the unit tests. - if self._tls_client_options_factory is None: - tls_options = None - else: - tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii") + # We have to do this here and not in the endpoint as we need to rewrite + # the host header with the delegated server name. + delegated_server = None + if ( + parsed_uri.scheme == b"matrix" + and not _is_ip_literal(parsed_uri.hostname) + and not parsed_uri.port + ): + well_known_result = yield self._well_known_resolver.get_well_known( + parsed_uri.hostname ) + delegated_server = well_known_result.delegated_server + + if delegated_server: + # Ok, the server has delegated matrix traffic to somewhere else, so + # lets rewrite the URL to replace the server with the delegated + # server name. + uri = urllib.parse.urlunparse( + ( + parsed_uri.scheme, + delegated_server, + parsed_uri.path, + parsed_uri.params, + parsed_uri.query, + parsed_uri.fragment, + ) + ) + parsed_uri = urllib.parse.urlparse(uri) - # make sure that the Host header is set correctly + # We need to make sure the host header is set to the netloc of the + # server. if headers is None: headers = Headers() else: headers = headers.copy() if not headers.hasHeader(b"host"): - headers.addRawHeader(b"host", res.host_header) - - class EndpointFactory(object): - @staticmethod - def endpointForURI(_uri): - ep = LoggingHostnameEndpoint( - self._reactor, res.target_host, res.target_port - ) - if tls_options is not None: - ep = wrapClientTLS(tls_options, ep) - return ep + headers.addRawHeader(b"host", parsed_uri.netloc) - agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool) res = yield make_deferred_yieldable( - agent.request(method, uri, headers, bodyProducer) + self._agent.request(method, uri, headers, bodyProducer) ) - return res - - @defer.inlineCallbacks - def _route_matrix_uri(self, parsed_uri, lookup_well_known=True): - """Helper for `request`: determine the routing for a Matrix URI - Args: - parsed_uri (twisted.web.client.URI): uri to route. Note that it should be - parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1 - if there is no explicit port given. + return res - lookup_well_known (bool): True if we should look up the .well-known file if - there is no SRV record. - Returns: - Deferred[_RoutingResult] - """ - # check for an IP literal - try: - ip_address = IPAddress(parsed_uri.host.decode("ascii")) - except Exception: - # not an IP address - ip_address = None - - if ip_address: - port = parsed_uri.port - if port == -1: - port = 8448 - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=port, - ) +@implementer(IAgentEndpointFactory) +class MatrixHostnameEndpointFactory(object): + """Factory for MatrixHostnameEndpoint for parsing to an Agent. + """ - if parsed_uri.port != -1: - # there is an explicit port - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=parsed_uri.port, - ) + def __init__(self, reactor, tls_client_options_factory, srv_resolver): + self._reactor = reactor + self._tls_client_options_factory = tls_client_options_factory - if lookup_well_known: - # try a .well-known lookup - well_known_server = yield self._get_well_known(parsed_uri.host) - - if well_known_server: - # if we found a .well-known, start again, but don't do another - # .well-known lookup. - - # parse the server name in the .well-known response into host/port. - # (This code is lifted from twisted.web.client.URI.fromBytes). - if b":" in well_known_server: - well_known_host, well_known_port = well_known_server.rsplit(b":", 1) - try: - well_known_port = int(well_known_port) - except ValueError: - # the part after the colon could not be parsed as an int - # - we assume it is an IPv6 literal with no port (the closing - # ']' stops it being parsed as an int) - well_known_host, well_known_port = well_known_server, -1 - else: - well_known_host, well_known_port = well_known_server, -1 - - new_uri = URI( - scheme=parsed_uri.scheme, - netloc=well_known_server, - host=well_known_host, - port=well_known_port, - path=parsed_uri.path, - params=parsed_uri.params, - query=parsed_uri.query, - fragment=parsed_uri.fragment, - ) + if srv_resolver is None: + srv_resolver = SrvResolver() - res = yield self._route_matrix_uri(new_uri, lookup_well_known=False) - return res - - # try a SRV lookup - service_name = b"_matrix._tcp.%s" % (parsed_uri.host,) - server_list = yield self._srv_resolver.resolve_service(service_name) - - if not server_list: - target_host = parsed_uri.host - port = 8448 - logger.debug( - "No SRV record for %s, using %s:%i", - parsed_uri.host.decode("ascii"), - target_host.decode("ascii"), - port, - ) - else: - target_host, port = pick_server_from_list(server_list) - logger.debug( - "Picked %s:%i from SRV records for %s", - target_host.decode("ascii"), - port, - parsed_uri.host.decode("ascii"), - ) + self._srv_resolver = srv_resolver - return _RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=target_host, - target_port=port, + def endpointForURI(self, parsed_uri): + return MatrixHostnameEndpoint( + self._reactor, + self._tls_client_options_factory, + self._srv_resolver, + parsed_uri, ) - @defer.inlineCallbacks - def _get_well_known(self, server_name): - """Attempt to fetch and parse a .well-known file for the given server - Args: - server_name (bytes): name of the server, from the requested url - - Returns: - Deferred[bytes|None]: either the new server name, from the .well-known, or - None if there was no .well-known file. - """ - try: - result = self._well_known_cache[server_name] - except KeyError: - # TODO: should we linearise so that we don't end up doing two .well-known - # requests for the same server in parallel? - with Measure(self._clock, "get_well_known"): - result, cache_period = yield self._do_get_well_known(server_name) +@implementer(IStreamClientEndpoint) +class MatrixHostnameEndpoint(object): + """An endpoint that resolves matrix:// URLs using Matrix server name + resolution (i.e. via SRV). Does not check for well-known delegation. - if cache_period > 0: - self._well_known_cache.set(server_name, result, cache_period) + Args: + reactor (IReactor) + tls_client_options_factory (ClientTLSOptionsFactory|None): + factory to use for fetching client tls options, or none to disable TLS. + srv_resolver (SrvResolver): The SRV resolver to use + parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting + to connect to. + """ - return result + def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri): + self._reactor = reactor - @defer.inlineCallbacks - def _do_get_well_known(self, server_name): - """Actually fetch and parse a .well-known, without checking the cache + self._parsed_uri = parsed_uri - Args: - server_name (bytes): name of the server, from the requested url + # set up the TLS connection params + # + # XXX disabling TLS is really only supported here for the benefit of the + # unit tests. We should make the UTs cope with TLS rather than having to make + # the code support the unit tests. - Returns: - Deferred[Tuple[bytes|None|object],int]: - result, cache period, where result is one of: - - the new server name from the .well-known (as a `bytes`) - - None if there was no .well-known file. - - INVALID_WELL_KNOWN if the .well-known was invalid - """ - uri = b"https://%s/.well-known/matrix/server" % (server_name,) - uri_str = uri.decode("ascii") - logger.info("Fetching %s", uri_str) - try: - response = yield make_deferred_yieldable( - self._well_known_agent.request(b"GET", uri) - ) - body = yield make_deferred_yieldable(readBody(response)) - if response.code != 200: - raise Exception("Non-200 response %s" % (response.code,)) - - parsed_body = json.loads(body.decode("utf-8")) - logger.info("Response from .well-known: %s", parsed_body) - if not isinstance(parsed_body, dict): - raise Exception("not a dict") - if "m.server" not in parsed_body: - raise Exception("Missing key 'm.server'") - except Exception as e: - logger.info("Error fetching %s: %s", uri_str, e) - - # add some randomness to the TTL to avoid a stampeding herd every hour - # after startup - cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD - cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) - return (None, cache_period) - - result = parsed_body["m.server"].encode("ascii") - - cache_period = _cache_period_from_headers( - response.headers, time_now=self._reactor.seconds - ) - if cache_period is None: - cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD - # add some randomness to the TTL to avoid a stampeding herd every 24 hours - # after startup - cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) + if tls_client_options_factory is None: + self._tls_options = None else: - cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD) - - return (result, cache_period) - - -@implementer(IStreamClientEndpoint) -class LoggingHostnameEndpoint(object): - """A wrapper for HostnameEndpint which logs when it connects""" + self._tls_options = tls_client_options_factory.get_options( + self._parsed_uri.host.decode("ascii") + ) - def __init__(self, reactor, host, port, *args, **kwargs): - self.host = host - self.port = port - self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs) + self._srv_resolver = srv_resolver def connect(self, protocol_factory): - logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port) - return self.ep.connect(protocol_factory) + """Implements IStreamClientEndpoint interface + """ + return run_in_background(self._do_connect, protocol_factory) -def _cache_period_from_headers(headers, time_now=time.time): - cache_controls = _parse_cache_control(headers) + @defer.inlineCallbacks + def _do_connect(self, protocol_factory): + first_exception = None + + server_list = yield self._resolve_server() + + for server in server_list: + host = server.host + port = server.port + + try: + logger.info("Connecting to %s:%i", host.decode("ascii"), port) + endpoint = HostnameEndpoint(self._reactor, host, port) + if self._tls_options: + endpoint = wrapClientTLS(self._tls_options, endpoint) + result = yield make_deferred_yieldable( + endpoint.connect(protocol_factory) + ) - if b"no-store" in cache_controls: - return 0 + return result + except Exception as e: + logger.info( + "Failed to connect to %s:%i: %s", host.decode("ascii"), port, e + ) + if not first_exception: + first_exception = e - if b"max-age" in cache_controls: - try: - max_age = int(cache_controls[b"max-age"]) - return max_age - except ValueError: - pass + # We return the first failure because that's probably the most interesting. + if first_exception: + raise first_exception - expires = headers.getRawHeaders(b"expires") - if expires is not None: - try: - expires_date = stringToDatetime(expires[-1]) - return expires_date - time_now() - except ValueError: - # RFC7234 says 'A cache recipient MUST interpret invalid date formats, - # especially the value "0", as representing a time in the past (i.e., - # "already expired"). - return 0 + # This shouldn't happen as we should always have at least one host/port + # to try and if that doesn't work then we'll have an exception. + raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,)) - return None + @defer.inlineCallbacks + def _resolve_server(self): + """Resolves the server name to a list of hosts and ports to attempt to + connect to. + Returns: + Deferred[list[Server]] + """ -def _parse_cache_control(headers): - cache_controls = {} - for hdr in headers.getRawHeaders(b"cache-control", []): - for directive in hdr.split(b","): - splits = [x.strip() for x in directive.split(b"=", 1)] - k = splits[0].lower() - v = splits[1] if len(splits) > 1 else None - cache_controls[k] = v - return cache_controls + if self._parsed_uri.scheme != b"matrix": + return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)] + # Note: We don't do well-known lookup as that needs to have happened + # before now, due to needing to rewrite the Host header of the HTTP + # request. -@attr.s -class _RoutingResult(object): - """The result returned by `_route_matrix_uri`. + # We reparse the URI so that defaultPort is -1 rather than 80 + parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes()) - Contains the parameters needed to direct a federation connection to a particular - server. + host = parsed_uri.hostname + port = parsed_uri.port - Where a SRV record points to several servers, this object contains a single server - chosen from the list. - """ + # If there is an explicit port or the host is an IP address we bypass + # SRV lookups and just use the given host/port. + if port or _is_ip_literal(host): + return [Server(host, port or 8448)] - host_header = attr.ib() - """ - The value we should assign to the Host header (host:port from the matrix - URI, or .well-known). + server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host) - :type: bytes - """ + if server_list: + return server_list - tls_server_name = attr.ib() - """ - The server name we should set in the SNI (typically host, without port, from the - matrix URI or .well-known) + # No SRV records, so we fallback to host and 8448 + return [Server(host, 8448)] - :type: bytes - """ - target_host = attr.ib() - """ - The hostname (or IP literal) we should route the TCP connection to (the target of the - SRV record, or the hostname from the URL/.well-known) +def _is_ip_literal(host): + """Test if the given host name is either an IPv4 or IPv6 literal. - :type: bytes - """ + Args: + host (bytes) - target_port = attr.ib() + Returns: + bool """ - The port we should route the TCP connection to (the target of the SRV record, or - the port from the URL/.well-known, or 8448) - :type: int - """ + host = host.decode("ascii") + + try: + IPAddress(host) + return True + except AddrFormatError: + return False diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index b32188766d..3fe4ffb9e5 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) SERVER_CACHE = {} -@attr.s +@attr.s(slots=True, frozen=True) class Server(object): """ Our record of an individual server which can be tried to reach a destination. @@ -53,34 +53,47 @@ class Server(object): expires = attr.ib(default=0) -def pick_server_from_list(server_list): - """Randomly choose a server from the server list +def _sort_server_list(server_list): + """Given a list of SRV records sort them into priority order and shuffle + each priority with the given weight. + """ + priority_map = {} - Args: - server_list (list[Server]): list of candidate servers + for server in server_list: + priority_map.setdefault(server.priority, []).append(server) - Returns: - Tuple[bytes, int]: (host, port) pair for the chosen server - """ - if not server_list: - raise RuntimeError("pick_server_from_list called with empty list") + results = [] + for priority in sorted(priority_map): + servers = priority_map[priority] + + # This algorithms roughly follows the algorithm described in RFC2782, + # changed to remove an off-by-one error. + # + # N.B. Weights can be zero, which means that they should be picked + # rarely. + + total_weight = sum(s.weight for s in servers) + + # Total weight can become zero if there are only zero weight servers + # left, which we handle by just shuffling and appending to the results. + while servers and total_weight: + target_weight = random.randint(1, total_weight) - # TODO: currently we only use the lowest-priority servers. We should maintain a - # cache of servers known to be "down" and filter them out + for s in servers: + target_weight -= s.weight - min_priority = min(s.priority for s in server_list) - eligible_servers = list(s for s in server_list if s.priority == min_priority) - total_weight = sum(s.weight for s in eligible_servers) - target_weight = random.randint(0, total_weight) + if target_weight <= 0: + break - for s in eligible_servers: - target_weight -= s.weight + results.append(s) + servers.remove(s) + total_weight -= s.weight - if target_weight <= 0: - return s.host, s.port + if servers: + random.shuffle(servers) + results.extend(servers) - # this should be impossible. - raise RuntimeError("pick_server_from_list got to end of eligible server list.") + return results class SrvResolver(object): @@ -120,7 +133,7 @@ class SrvResolver(object): if cache_entry: if all(s.expires > now for s in cache_entry): servers = list(cache_entry) - return servers + return _sort_server_list(servers) try: answers, _, _ = yield make_deferred_yieldable( @@ -169,4 +182,4 @@ class SrvResolver(object): ) self._cache[service_name] = list(servers) - return servers + return _sort_server_list(servers) diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py new file mode 100644 index 0000000000..5e9b0befb0 --- /dev/null +++ b/synapse/http/federation/well_known_resolver.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import random +import time + +import attr + +from twisted.internet import defer +from twisted.web.client import RedirectAgent, readBody +from twisted.web.http import stringToDatetime + +from synapse.logging.context import make_deferred_yieldable +from synapse.util import Clock +from synapse.util.caches.ttlcache import TTLCache +from synapse.util.metrics import Measure + +# period to cache .well-known results for by default +WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600 + +# jitter factor to add to the .well-known default cache ttls +WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 0.1 + +# period to cache failure to fetch .well-known for +WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 + +# period to cache failure to fetch .well-known if there has recently been a +# valid well-known for that domain. +WELL_KNOWN_DOWN_CACHE_PERIOD = 2 * 60 + +# period to remember there was a valid well-known after valid record expires +WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID = 2 * 3600 + +# cap for .well-known cache period +WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 + +# lower bound for .well-known cache period +WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60 + +# Attempt to refetch a cached well-known N% of the TTL before it expires. +# e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then +# we'll start trying to refetch 1 minute before it expires. +WELL_KNOWN_GRACE_PERIOD_FACTOR = 0.2 + +# Number of times we retry fetching a well-known for a domain we know recently +# had a valid entry. +WELL_KNOWN_RETRY_ATTEMPTS = 3 + + +logger = logging.getLogger(__name__) + + +_well_known_cache = TTLCache("well-known") +_had_valid_well_known_cache = TTLCache("had-valid-well-known") + + +@attr.s(slots=True, frozen=True) +class WellKnownLookupResult(object): + delegated_server = attr.ib() + + +class WellKnownResolver(object): + """Handles well-known lookups for matrix servers. + """ + + def __init__( + self, reactor, agent, well_known_cache=None, had_well_known_cache=None + ): + self._reactor = reactor + self._clock = Clock(reactor) + + if well_known_cache is None: + well_known_cache = _well_known_cache + + if had_well_known_cache is None: + had_well_known_cache = _had_valid_well_known_cache + + self._well_known_cache = well_known_cache + self._had_valid_well_known_cache = had_well_known_cache + self._well_known_agent = RedirectAgent(agent) + + @defer.inlineCallbacks + def get_well_known(self, server_name): + """Attempt to fetch and parse a .well-known file for the given server + + Args: + server_name (bytes): name of the server, from the requested url + + Returns: + Deferred[WellKnownLookupResult]: The result of the lookup + """ + try: + prev_result, expiry, ttl = self._well_known_cache.get_with_expiry( + server_name + ) + + now = self._clock.time() + if now < expiry - WELL_KNOWN_GRACE_PERIOD_FACTOR * ttl: + return WellKnownLookupResult(delegated_server=prev_result) + except KeyError: + prev_result = None + + # TODO: should we linearise so that we don't end up doing two .well-known + # requests for the same server in parallel? + try: + with Measure(self._clock, "get_well_known"): + result, cache_period = yield self._fetch_well_known(server_name) + + except _FetchWellKnownFailure as e: + if prev_result and e.temporary: + # This is a temporary failure and we have a still valid cached + # result, so lets return that. Hopefully the next time we ask + # the remote will be back up again. + return WellKnownLookupResult(delegated_server=prev_result) + + result = None + + if self._had_valid_well_known_cache.get(server_name, False): + # We have recently seen a valid well-known record for this + # server, so we cache the lack of well-known for a shorter time. + cache_period = WELL_KNOWN_DOWN_CACHE_PERIOD + else: + cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD + + # add some randomness to the TTL to avoid a stampeding herd + cache_period *= random.uniform( + 1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER, + 1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER, + ) + + if cache_period > 0: + self._well_known_cache.set(server_name, result, cache_period) + + return WellKnownLookupResult(delegated_server=result) + + @defer.inlineCallbacks + def _fetch_well_known(self, server_name): + """Actually fetch and parse a .well-known, without checking the cache + + Args: + server_name (bytes): name of the server, from the requested url + + Raises: + _FetchWellKnownFailure if we fail to lookup a result + + Returns: + Deferred[Tuple[bytes,int]]: The lookup result and cache period. + """ + + had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False) + + # We do this in two steps to differentiate between possibly transient + # errors (e.g. can't connect to host, 503 response) and more permenant + # errors (such as getting a 404 response). + response, body = yield self._make_well_known_request( + server_name, retry=had_valid_well_known + ) + + try: + if response.code != 200: + raise Exception("Non-200 response %s" % (response.code,)) + + parsed_body = json.loads(body.decode("utf-8")) + logger.info("Response from .well-known: %s", parsed_body) + + result = parsed_body["m.server"].encode("ascii") + except defer.CancelledError: + # Bail if we've been cancelled + raise + except Exception as e: + logger.info("Error parsing well-known for %s: %s", server_name, e) + raise _FetchWellKnownFailure(temporary=False) + + cache_period = _cache_period_from_headers( + response.headers, time_now=self._reactor.seconds + ) + if cache_period is None: + cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD + # add some randomness to the TTL to avoid a stampeding herd every 24 hours + # after startup + cache_period *= random.uniform( + 1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER, + 1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER, + ) + else: + cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD) + cache_period = max(cache_period, WELL_KNOWN_MIN_CACHE_PERIOD) + + # We got a success, mark as such in the cache + self._had_valid_well_known_cache.set( + server_name, + bool(result), + cache_period + WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID, + ) + + return (result, cache_period) + + @defer.inlineCallbacks + def _make_well_known_request(self, server_name, retry): + """Make the well known request. + + This will retry the request if requested and it fails (with unable + to connect or receives a 5xx error). + + Args: + server_name (bytes) + retry (bool): Whether to retry the request if it fails. + + Returns: + Deferred[tuple[IResponse, bytes]] Returns the response object and + body. Response may be a non-200 response. + """ + uri = b"https://%s/.well-known/matrix/server" % (server_name,) + uri_str = uri.decode("ascii") + + i = 0 + while True: + i += 1 + + logger.info("Fetching %s", uri_str) + try: + response = yield make_deferred_yieldable( + self._well_known_agent.request(b"GET", uri) + ) + body = yield make_deferred_yieldable(readBody(response)) + + if 500 <= response.code < 600: + raise Exception("Non-200 response %s" % (response.code,)) + + return response, body + except defer.CancelledError: + # Bail if we've been cancelled + raise + except Exception as e: + if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS: + logger.info("Error fetching %s: %s", uri_str, e) + raise _FetchWellKnownFailure(temporary=True) + + logger.info("Error fetching %s: %s. Retrying", uri_str, e) + + # Sleep briefly in the hopes that they come back up + yield self._clock.sleep(0.5) + + +def _cache_period_from_headers(headers, time_now=time.time): + cache_controls = _parse_cache_control(headers) + + if b"no-store" in cache_controls: + return 0 + + if b"max-age" in cache_controls: + try: + max_age = int(cache_controls[b"max-age"]) + return max_age + except ValueError: + pass + + expires = headers.getRawHeaders(b"expires") + if expires is not None: + try: + expires_date = stringToDatetime(expires[-1]) + return expires_date - time_now() + except ValueError: + # RFC7234 says 'A cache recipient MUST interpret invalid date formats, + # especially the value "0", as representing a time in the past (i.e., + # "already expired"). + return 0 + + return None + + +def _parse_cache_control(headers): + cache_controls = {} + for hdr in headers.getRawHeaders(b"cache-control", []): + for directive in hdr.split(b","): + splits = [x.strip() for x in directive.split(b"=", 1)] + k = splits[0].lower() + v = splits[1] if len(splits) > 1 else None + cache_controls[k] = v + return cache_controls + + +@attr.s() +class _FetchWellKnownFailure(Exception): + # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be + # a temporary failure. + temporary = attr.ib() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index d07d356464..4326e98a28 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -36,7 +36,6 @@ from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone from twisted.web.http_headers import Headers -import synapse.logging.opentracing as opentracing import synapse.metrics import synapse.util.retryutils from synapse.api.errors import ( @@ -50,6 +49,12 @@ from synapse.http import QuieterFileBodyProducer from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.logging.context import make_deferred_yieldable +from synapse.logging.opentracing import ( + inject_active_span_byte_dict, + set_tag, + start_active_span, + tags, +) from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -341,20 +346,20 @@ class MatrixFederationHttpClient(object): query_bytes = b"" # Retreive current span - scope = opentracing.start_active_span( + scope = start_active_span( "outgoing-federation-request", tags={ - opentracing.tags.SPAN_KIND: opentracing.tags.SPAN_KIND_RPC_CLIENT, - opentracing.tags.PEER_ADDRESS: request.destination, - opentracing.tags.HTTP_METHOD: request.method, - opentracing.tags.HTTP_URL: request.path, + tags.SPAN_KIND: tags.SPAN_KIND_RPC_CLIENT, + tags.PEER_ADDRESS: request.destination, + tags.HTTP_METHOD: request.method, + tags.HTTP_URL: request.path, }, finish_on_close=True, ) # Inject the span into the headers headers_dict = {} - opentracing.inject_active_span_byte_dict(headers_dict, request.destination) + inject_active_span_byte_dict(headers_dict, request.destination) headers_dict[b"User-Agent"] = [self.version_string_bytes] @@ -436,9 +441,7 @@ class MatrixFederationHttpClient(object): response.phrase.decode("ascii", errors="replace"), ) - opentracing.set_tag( - opentracing.tags.HTTP_STATUS_CODE, response.code - ) + set_tag(tags.HTTP_STATUS_CODE, response.code) if 200 <= response.code < 300: pass diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index fd07bf7b8e..c186b31f59 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -300,7 +300,7 @@ class RestServlet(object): http_server.register_paths( method, patterns, - trace_servlet(servlet_classname, method_handler), + trace_servlet(servlet_classname)(method_handler), servlet_classname, ) |