diff --git a/synapse/http/client.py b/synapse/http/client.py
index afcf698b29..47a1f82ff0 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -333,9 +333,10 @@ class SimpleHttpClient(object):
"POST", uri, headers=Headers(actual_headers), data=query_bytes
)
+ body = yield make_deferred_yieldable(readBody(response))
+
if 200 <= response.code < 300:
- body = yield make_deferred_yieldable(treq.json_content(response))
- defer.returnValue(body)
+ defer.returnValue(json.loads(body))
else:
raise HttpResponseException(response.code, response.phrase, body)
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 64c780a341..384d8a37a2 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -12,20 +12,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import logging
+import random
+import time
+import attr
+from netaddr import IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.web.client import URI, Agent, HTTPConnectionPool
+from twisted.internet.interfaces import IStreamClientEndpoint
+from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody
+from twisted.web.http import stringToDatetime
+from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent
-from synapse.http.endpoint import parse_server_name
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.util import Clock
+from synapse.util.caches.ttlcache import TTLCache
from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.metrics import Measure
+
+# period to cache .well-known results for by default
+WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
+
+# jitter to add to the .well-known default cache ttl
+WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
+
+# period to cache failure to fetch .well-known for
+WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
+
+# cap for .well-known cache period
+WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
logger = logging.getLogger(__name__)
+well_known_cache = TTLCache('well-known')
@implementer(IAgent)
@@ -41,15 +64,24 @@ class MatrixFederationAgent(object):
tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS.
+ _well_known_tls_policy (IPolicyForHTTPS|None):
+ TLS policy to use for fetching .well-known files. None to use a default
+ (browser-like) implementation.
+
srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
"""
def __init__(
- self, reactor, tls_client_options_factory, _srv_resolver=None,
+ self, reactor, tls_client_options_factory,
+ _well_known_tls_policy=None,
+ _srv_resolver=None,
+ _well_known_cache=well_known_cache,
):
self._reactor = reactor
+ self._clock = Clock(reactor)
+
self._tls_client_options_factory = tls_client_options_factory
if _srv_resolver is None:
_srv_resolver = SrvResolver()
@@ -60,6 +92,22 @@ class MatrixFederationAgent(object):
self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60
+ agent_args = {}
+ if _well_known_tls_policy is not None:
+ # the param is called 'contextFactory', but actually passing a
+ # contextfactory is deprecated, and it expects an IPolicyForHTTPS.
+ agent_args['contextFactory'] = _well_known_tls_policy
+ _well_known_agent = RedirectAgent(
+ Agent(self._reactor, pool=self._pool, **agent_args),
+ )
+ self._well_known_agent = _well_known_agent
+
+ # our cache of .well-known lookup results, mapping from server name
+ # to delegated name. The values can be:
+ # `bytes`: a valid server-name
+ # `None`: there is no (valid) .well-known here
+ self._well_known_cache = _well_known_cache
+
@defer.inlineCallbacks
def request(self, method, uri, headers=None, bodyProducer=None):
"""
@@ -85,34 +133,36 @@ class MatrixFederationAgent(object):
response from being received (including problems that prevent the request
from being sent).
"""
+ parsed_uri = URI.fromBytes(uri, defaultPort=-1)
+ res = yield self._route_matrix_uri(parsed_uri)
- parsed_uri = URI.fromBytes(uri)
- server_name_bytes = parsed_uri.netloc
- host, port = parse_server_name(server_name_bytes.decode("ascii"))
-
+ # set up the TLS connection params
+ #
# XXX disabling TLS is really only supported here for the benefit of the
# unit tests. We should make the UTs cope with TLS rather than having to make
# the code support the unit tests.
if self._tls_client_options_factory is None:
tls_options = None
else:
- tls_options = self._tls_client_options_factory.get_options(host)
+ tls_options = self._tls_client_options_factory.get_options(
+ res.tls_server_name.decode("ascii")
+ )
- if port is not None:
- target = (host, port)
+ # make sure that the Host header is set correctly
+ if headers is None:
+ headers = Headers()
else:
- server_list = yield self._srv_resolver.resolve_service(server_name_bytes)
- if not server_list:
- target = (host, 8448)
- logger.debug("No SRV record for %s, using %s", host, target)
- else:
- target = pick_server_from_list(server_list)
+ headers = headers.copy()
+
+ if not headers.hasHeader(b'host'):
+ headers.addRawHeader(b'host', res.host_header)
class EndpointFactory(object):
@staticmethod
def endpointForURI(_uri):
- logger.info("Connecting to %s:%s", target[0], target[1])
- ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1])
+ ep = LoggingHostnameEndpoint(
+ self._reactor, res.target_host, res.target_port,
+ )
if tls_options is not None:
ep = wrapClientTLS(tls_options, ep)
return ep
@@ -122,3 +172,281 @@ class MatrixFederationAgent(object):
agent.request(method, uri, headers, bodyProducer)
)
defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
+ """Helper for `request`: determine the routing for a Matrix URI
+
+ Args:
+ parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
+ parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
+ if there is no explicit port given.
+
+ lookup_well_known (bool): True if we should look up the .well-known file if
+ there is no SRV record.
+
+ Returns:
+ Deferred[_RoutingResult]
+ """
+ # check for an IP literal
+ try:
+ ip_address = IPAddress(parsed_uri.host.decode("ascii"))
+ except Exception:
+ # not an IP address
+ ip_address = None
+
+ if ip_address:
+ port = parsed_uri.port
+ if port == -1:
+ port = 8448
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=port,
+ ))
+
+ if parsed_uri.port != -1:
+ # there is an explicit port
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=parsed_uri.port,
+ ))
+
+ if lookup_well_known:
+ # try a .well-known lookup
+ well_known_server = yield self._get_well_known(parsed_uri.host)
+
+ if well_known_server:
+ # if we found a .well-known, start again, but don't do another
+ # .well-known lookup.
+
+ # parse the server name in the .well-known response into host/port.
+ # (This code is lifted from twisted.web.client.URI.fromBytes).
+ if b':' in well_known_server:
+ well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
+ try:
+ well_known_port = int(well_known_port)
+ except ValueError:
+ # the part after the colon could not be parsed as an int
+ # - we assume it is an IPv6 literal with no port (the closing
+ # ']' stops it being parsed as an int)
+ well_known_host, well_known_port = well_known_server, -1
+ else:
+ well_known_host, well_known_port = well_known_server, -1
+
+ new_uri = URI(
+ scheme=parsed_uri.scheme,
+ netloc=well_known_server,
+ host=well_known_host,
+ port=well_known_port,
+ path=parsed_uri.path,
+ params=parsed_uri.params,
+ query=parsed_uri.query,
+ fragment=parsed_uri.fragment,
+ )
+
+ res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
+ defer.returnValue(res)
+
+ # try a SRV lookup
+ service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
+ server_list = yield self._srv_resolver.resolve_service(service_name)
+
+ if not server_list:
+ target_host = parsed_uri.host
+ port = 8448
+ logger.debug(
+ "No SRV record for %s, using %s:%i",
+ parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
+ )
+ else:
+ target_host, port = pick_server_from_list(server_list)
+ logger.debug(
+ "Picked %s:%i from SRV records for %s",
+ target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
+ )
+
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=target_host,
+ target_port=port,
+ ))
+
+ @defer.inlineCallbacks
+ def _get_well_known(self, server_name):
+ """Attempt to fetch and parse a .well-known file for the given server
+
+ Args:
+ server_name (bytes): name of the server, from the requested url
+
+ Returns:
+ Deferred[bytes|None]: either the new server name, from the .well-known, or
+ None if there was no .well-known file.
+ """
+ try:
+ result = self._well_known_cache[server_name]
+ except KeyError:
+ # TODO: should we linearise so that we don't end up doing two .well-known
+ # requests for the same server in parallel?
+ with Measure(self._clock, "get_well_known"):
+ result, cache_period = yield self._do_get_well_known(server_name)
+
+ if cache_period > 0:
+ self._well_known_cache.set(server_name, result, cache_period)
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _do_get_well_known(self, server_name):
+ """Actually fetch and parse a .well-known, without checking the cache
+
+ Args:
+ server_name (bytes): name of the server, from the requested url
+
+ Returns:
+ Deferred[Tuple[bytes|None|object],int]:
+ result, cache period, where result is one of:
+ - the new server name from the .well-known (as a `bytes`)
+ - None if there was no .well-known file.
+ - INVALID_WELL_KNOWN if the .well-known was invalid
+ """
+ uri = b"https://%s/.well-known/matrix/server" % (server_name, )
+ uri_str = uri.decode("ascii")
+ logger.info("Fetching %s", uri_str)
+ try:
+ response = yield make_deferred_yieldable(
+ self._well_known_agent.request(b"GET", uri),
+ )
+ body = yield make_deferred_yieldable(readBody(response))
+ if response.code != 200:
+ raise Exception("Non-200 response %s" % (response.code, ))
+
+ parsed_body = json.loads(body.decode('utf-8'))
+ logger.info("Response from .well-known: %s", parsed_body)
+ if not isinstance(parsed_body, dict):
+ raise Exception("not a dict")
+ if "m.server" not in parsed_body:
+ raise Exception("Missing key 'm.server'")
+ except Exception as e:
+ logger.info("Error fetching %s: %s", uri_str, e)
+
+ # add some randomness to the TTL to avoid a stampeding herd every hour
+ # after startup
+ cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
+ cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
+ defer.returnValue((None, cache_period))
+
+ result = parsed_body["m.server"].encode("ascii")
+
+ cache_period = _cache_period_from_headers(
+ response.headers,
+ time_now=self._reactor.seconds,
+ )
+ if cache_period is None:
+ cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
+ # add some randomness to the TTL to avoid a stampeding herd every 24 hours
+ # after startup
+ cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
+ else:
+ cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
+
+ defer.returnValue((result, cache_period))
+
+
+@implementer(IStreamClientEndpoint)
+class LoggingHostnameEndpoint(object):
+ """A wrapper for HostnameEndpint which logs when it connects"""
+ def __init__(self, reactor, host, port, *args, **kwargs):
+ self.host = host
+ self.port = port
+ self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
+
+ def connect(self, protocol_factory):
+ logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
+ return self.ep.connect(protocol_factory)
+
+
+def _cache_period_from_headers(headers, time_now=time.time):
+ cache_controls = _parse_cache_control(headers)
+
+ if b'no-store' in cache_controls:
+ return 0
+
+ if b'max-age' in cache_controls:
+ try:
+ max_age = int(cache_controls[b'max-age'])
+ return max_age
+ except ValueError:
+ pass
+
+ expires = headers.getRawHeaders(b'expires')
+ if expires is not None:
+ try:
+ expires_date = stringToDatetime(expires[-1])
+ return expires_date - time_now()
+ except ValueError:
+ # RFC7234 says 'A cache recipient MUST interpret invalid date formats,
+ # especially the value "0", as representing a time in the past (i.e.,
+ # "already expired").
+ return 0
+
+ return None
+
+
+def _parse_cache_control(headers):
+ cache_controls = {}
+ for hdr in headers.getRawHeaders(b'cache-control', []):
+ for directive in hdr.split(b','):
+ splits = [x.strip() for x in directive.split(b'=', 1)]
+ k = splits[0].lower()
+ v = splits[1] if len(splits) > 1 else None
+ cache_controls[k] = v
+ return cache_controls
+
+
+@attr.s
+class _RoutingResult(object):
+ """The result returned by `_route_matrix_uri`.
+
+ Contains the parameters needed to direct a federation connection to a particular
+ server.
+
+ Where a SRV record points to several servers, this object contains a single server
+ chosen from the list.
+ """
+
+ host_header = attr.ib()
+ """
+ The value we should assign to the Host header (host:port from the matrix
+ URI, or .well-known).
+
+ :type: bytes
+ """
+
+ tls_server_name = attr.ib()
+ """
+ The server name we should set in the SNI (typically host, without port, from the
+ matrix URI or .well-known)
+
+ :type: bytes
+ """
+
+ target_host = attr.ib()
+ """
+ The hostname (or IP literal) we should route the TCP connection to (the target of the
+ SRV record, or the hostname from the URL/.well-known)
+
+ :type: bytes
+ """
+
+ target_port = attr.ib()
+ """
+ The port we should route the TCP connection to (the target of the SRV record, or
+ the port from the URL/.well-known, or 8448)
+
+ :type: int
+ """
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 980e912348..3c24bf3805 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -28,7 +28,7 @@ from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
-from twisted.internet import defer, protocol
+from twisted.internet import defer, protocol, task
from twisted.internet.error import DNSLookupError
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
@@ -168,7 +168,7 @@ class MatrixFederationHttpClient(object):
requests.
"""
- def __init__(self, hs):
+ def __init__(self, hs, tls_client_options_factory):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
@@ -176,7 +176,7 @@ class MatrixFederationHttpClient(object):
self.agent = MatrixFederationAgent(
hs.get_reactor(),
- hs.tls_client_options_factory,
+ tls_client_options_factory,
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
@@ -255,7 +255,6 @@ class MatrixFederationHttpClient(object):
headers_dict = {
b"User-Agent": [self.version_string_bytes],
- b"Host": [destination_bytes],
}
with limiter:
@@ -287,7 +286,7 @@ class MatrixFederationHttpClient(object):
json,
)
data = encode_canonical_json(json)
- producer = FileBodyProducer(
+ producer = QuieterFileBodyProducer(
BytesIO(data),
cooperator=self._cooperator,
)
@@ -840,3 +839,16 @@ def encode_query_args(args):
query_bytes = urllib.parse.urlencode(encoded_args, True)
return query_bytes.encode('utf8')
+
+
+class QuieterFileBodyProducer(FileBodyProducer):
+ """Wrapper for FileBodyProducer that avoids CRITICAL errors when the connection drops.
+
+ Workaround for https://github.com/matrix-org/synapse/issues/4003 /
+ https://twistedmatrix.com/trac/ticket/6528
+ """
+ def stopProducing(self):
+ try:
+ FileBodyProducer.stopProducing(self)
+ except task.TaskStopped:
+ pass
|