diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index a3f9e4f67c..d36bcd6336 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -15,8 +15,10 @@
# limitations under the License.
import re
+from twisted.internet import task
from twisted.internet.defer import CancelledError
from twisted.python import failure
+from twisted.web.client import FileBodyProducer
from synapse.api.errors import SynapseError
@@ -47,3 +49,16 @@ def redact_uri(uri):
r'\1<redacted>\3',
uri
)
+
+
+class QuieterFileBodyProducer(FileBodyProducer):
+ """Wrapper for FileBodyProducer that avoids CRITICAL errors when the connection drops.
+
+ Workaround for https://github.com/matrix-org/synapse/issues/4003 /
+ https://twistedmatrix.com/trac/ticket/6528
+ """
+ def stopProducing(self):
+ try:
+ FileBodyProducer.stopProducing(self)
+ except task.TaskStopped:
+ pass
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 3d05f83b8c..ad454f4964 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -15,34 +15,36 @@
# limitations under the License.
import logging
+from io import BytesIO
from six import text_type
from six.moves import urllib
import treq
from canonicaljson import encode_canonical_json, json
+from netaddr import IPAddress
from prometheus_client import Counter
+from zope.interface import implementer, provider
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
-from twisted.internet import defer, protocol, reactor, ssl
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.web._newclient import ResponseDone
-from twisted.web.client import (
- Agent,
- BrowserLikeRedirectAgent,
- ContentDecoderAgent,
- GzipDecoder,
- HTTPConnectionPool,
- PartialDownloadError,
- readBody,
+from twisted.internet import defer, protocol, ssl
+from twisted.internet.interfaces import (
+ IReactorPluggableNameResolver,
+ IResolutionReceiver,
)
+from twisted.python.failure import Failure
+from twisted.web._newclient import ResponseDone
+from twisted.web.client import Agent, HTTPConnectionPool, PartialDownloadError, readBody
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from synapse.api.errors import Codes, HttpResponseException, SynapseError
-from synapse.http import cancelled_to_request_timed_out_error, redact_uri
-from synapse.http.endpoint import SpiderEndpoint
+from synapse.http import (
+ QuieterFileBodyProducer,
+ cancelled_to_request_timed_out_error,
+ redact_uri,
+)
from synapse.util.async_helpers import timeout_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable
@@ -50,8 +52,125 @@ from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
-incoming_responses_counter = Counter("synapse_http_client_responses", "",
- ["method", "code"])
+incoming_responses_counter = Counter(
+ "synapse_http_client_responses", "", ["method", "code"]
+)
+
+
+def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
+ """
+ Args:
+ ip_address (netaddr.IPAddress)
+ ip_whitelist (netaddr.IPSet)
+ ip_blacklist (netaddr.IPSet)
+ """
+ if ip_address in ip_blacklist:
+ if ip_whitelist is None or ip_address not in ip_whitelist:
+ return True
+ return False
+
+
+class IPBlacklistingResolver(object):
+ """
+ A proxy for reactor.nameResolver which only produces non-blacklisted IP
+ addresses, preventing DNS rebinding attacks on URL preview.
+ """
+
+ def __init__(self, reactor, ip_whitelist, ip_blacklist):
+ """
+ Args:
+ reactor (twisted.internet.reactor)
+ ip_whitelist (netaddr.IPSet)
+ ip_blacklist (netaddr.IPSet)
+ """
+ self._reactor = reactor
+ self._ip_whitelist = ip_whitelist
+ self._ip_blacklist = ip_blacklist
+
+ def resolveHostName(self, recv, hostname, portNumber=0):
+
+ r = recv()
+ d = defer.Deferred()
+ addresses = []
+
+ @provider(IResolutionReceiver)
+ class EndpointReceiver(object):
+ @staticmethod
+ def resolutionBegan(resolutionInProgress):
+ pass
+
+ @staticmethod
+ def addressResolved(address):
+ ip_address = IPAddress(address.host)
+
+ if check_against_blacklist(
+ ip_address, self._ip_whitelist, self._ip_blacklist
+ ):
+ logger.info(
+ "Dropped %s from DNS resolution to %s" % (ip_address, hostname)
+ )
+ raise SynapseError(403, "IP address blocked by IP blacklist entry")
+
+ addresses.append(address)
+
+ @staticmethod
+ def resolutionComplete():
+ d.callback(addresses)
+
+ self._reactor.nameResolver.resolveHostName(
+ EndpointReceiver, hostname, portNumber=portNumber
+ )
+
+ def _callback(addrs):
+ r.resolutionBegan(None)
+ for i in addrs:
+ r.addressResolved(i)
+ r.resolutionComplete()
+
+ d.addCallback(_callback)
+
+ return r
+
+
+class BlacklistingAgentWrapper(Agent):
+ """
+ An Agent wrapper which will prevent access to IP addresses being accessed
+ directly (without an IP address lookup).
+ """
+
+ def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
+ """
+ Args:
+ agent (twisted.web.client.Agent): The Agent to wrap.
+ reactor (twisted.internet.reactor)
+ ip_whitelist (netaddr.IPSet)
+ ip_blacklist (netaddr.IPSet)
+ """
+ self._agent = agent
+ self._ip_whitelist = ip_whitelist
+ self._ip_blacklist = ip_blacklist
+
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ h = urllib.parse.urlparse(uri.decode('ascii'))
+
+ try:
+ ip_address = IPAddress(h.hostname)
+
+ if check_against_blacklist(
+ ip_address, self._ip_whitelist, self._ip_blacklist
+ ):
+ logger.info(
+ "Blocking access to %s because of blacklist" % (ip_address,)
+ )
+ e = SynapseError(403, "IP address blocked by IP blacklist entry")
+ return defer.fail(Failure(e))
+ except Exception:
+ # Not an IP
+ pass
+
+ return self._agent.request(
+ method, uri, headers=headers, bodyProducer=bodyProducer
+ )
class SimpleHttpClient(object):
@@ -59,14 +178,54 @@ class SimpleHttpClient(object):
A simple, no-frills HTTP client with methods that wrap up common ways of
using HTTP in Matrix
"""
- def __init__(self, hs):
+
+ def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
+ """
+ Args:
+ hs (synapse.server.HomeServer)
+ treq_args (dict): Extra keyword arguments to be given to treq.request.
+ ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
+ we may not request.
+ ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
+ request if it were otherwise caught in a blacklist.
+ """
self.hs = hs
- pool = HTTPConnectionPool(reactor)
+ self._ip_whitelist = ip_whitelist
+ self._ip_blacklist = ip_blacklist
+ self._extra_treq_args = treq_args
+
+ self.user_agent = hs.version_string
+ self.clock = hs.get_clock()
+ if hs.config.user_agent_suffix:
+ self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
+
+ self.user_agent = self.user_agent.encode('ascii')
+
+ if self._ip_blacklist:
+ real_reactor = hs.get_reactor()
+ # If we have an IP blacklist, we need to use a DNS resolver which
+ # filters out blacklisted IP addresses, to prevent DNS rebinding.
+ nameResolver = IPBlacklistingResolver(
+ real_reactor, self._ip_whitelist, self._ip_blacklist
+ )
+
+ @implementer(IReactorPluggableNameResolver)
+ class Reactor(object):
+ def __getattr__(_self, attr):
+ if attr == "nameResolver":
+ return nameResolver
+ else:
+ return getattr(real_reactor, attr)
+
+ self.reactor = Reactor()
+ else:
+ self.reactor = hs.get_reactor()
# the pusher makes lots of concurrent SSL connections to sygnal, and
- # tends to do so in batches, so we need to allow the pool to keep lots
- # of idle connections around.
+ # tends to do so in batches, so we need to allow the pool to keep
+ # lots of idle connections around.
+ pool = HTTPConnectionPool(self.reactor)
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
pool.cachedConnectionTimeout = 2 * 60
@@ -74,20 +233,35 @@ class SimpleHttpClient(object):
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(
- reactor,
+ self.reactor,
connectTimeout=15,
- contextFactory=hs.get_http_client_context_factory(),
+ contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
)
- self.user_agent = hs.version_string
- self.clock = hs.get_clock()
- if hs.config.user_agent_suffix:
- self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
- self.user_agent = self.user_agent.encode('ascii')
+ if self._ip_blacklist:
+ # If we have an IP blacklist, we then install the blacklisting Agent
+ # which prevents direct access to IP addresses, that are not caught
+ # by the DNS resolution.
+ self.agent = BlacklistingAgentWrapper(
+ self.agent,
+ self.reactor,
+ ip_whitelist=self._ip_whitelist,
+ ip_blacklist=self._ip_blacklist,
+ )
@defer.inlineCallbacks
- def request(self, method, uri, data=b'', headers=None):
+ def request(self, method, uri, data=None, headers=None):
+ """
+ Args:
+ method (str): HTTP method to use.
+ uri (str): URI to query.
+ data (bytes): Data to send in the request body, if applicable.
+ headers (t.w.http_headers.Headers): Request headers.
+
+ Raises:
+ SynapseError: If the IP is blacklisted.
+ """
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.labels(method).inc()
@@ -96,26 +270,39 @@ class SimpleHttpClient(object):
logger.info("Sending request %s %s", method, redact_uri(uri))
try:
+ body_producer = None
+ if data is not None:
+ body_producer = QuieterFileBodyProducer(BytesIO(data))
+
request_deferred = treq.request(
- method, uri, agent=self.agent, data=data, headers=headers
+ method,
+ uri,
+ agent=self.agent,
+ data=body_producer,
+ headers=headers,
+ **self._extra_treq_args
)
request_deferred = timeout_deferred(
- request_deferred, 60, self.hs.get_reactor(),
+ request_deferred,
+ 60,
+ self.hs.get_reactor(),
cancelled_to_request_timed_out_error,
)
response = yield make_deferred_yieldable(request_deferred)
incoming_responses_counter.labels(method, response.code).inc()
logger.info(
- "Received response to %s %s: %s",
- method, redact_uri(uri), response.code
+ "Received response to %s %s: %s", method, redact_uri(uri), response.code
)
defer.returnValue(response)
except Exception as e:
incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
"Error sending request to %s %s: %s %s",
- method, redact_uri(uri), type(e).__name__, e.args[0]
+ method,
+ redact_uri(uri),
+ type(e).__name__,
+ e.args[0],
)
raise
@@ -140,8 +327,9 @@ class SimpleHttpClient(object):
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
- query_bytes = urllib.parse.urlencode(
- encode_urlencode_args(args), True).encode("utf8")
+ query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
+ "utf8"
+ )
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -151,15 +339,13 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
response = yield self.request(
- "POST",
- uri,
- headers=Headers(actual_headers),
- data=query_bytes
+ "POST", uri, headers=Headers(actual_headers), data=query_bytes
)
+ body = yield make_deferred_yieldable(readBody(response))
+
if 200 <= response.code < 300:
- body = yield make_deferred_yieldable(treq.json_content(response))
- defer.returnValue(body)
+ defer.returnValue(json.loads(body))
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -193,10 +379,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
response = yield self.request(
- "POST",
- uri,
- headers=Headers(actual_headers),
- data=json_str
+ "POST", uri, headers=Headers(actual_headers), data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@@ -264,10 +447,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
response = yield self.request(
- "PUT",
- uri,
- headers=Headers(actual_headers),
- data=json_str
+ "PUT", uri, headers=Headers(actual_headers), data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@@ -299,17 +479,11 @@ class SimpleHttpClient(object):
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
- actual_headers = {
- b"User-Agent": [self.user_agent],
- }
+ actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
- response = yield self.request(
- "GET",
- uri,
- headers=Headers(actual_headers),
- )
+ response = yield self.request("GET", uri, headers=Headers(actual_headers))
body = yield make_deferred_yieldable(readBody(response))
@@ -334,22 +508,18 @@ class SimpleHttpClient(object):
headers, absolute URI of the response and HTTP response code.
"""
- actual_headers = {
- b"User-Agent": [self.user_agent],
- }
+ actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
- response = yield self.request(
- "GET",
- url,
- headers=Headers(actual_headers),
- )
+ response = yield self.request("GET", url, headers=Headers(actual_headers))
resp_headers = dict(response.headers.getAllRawHeaders())
- if (b'Content-Length' in resp_headers and
- int(resp_headers[b'Content-Length']) > max_size):
+ if (
+ b'Content-Length' in resp_headers
+ and int(resp_headers[b'Content-Length'][0]) > max_size
+ ):
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
@@ -359,26 +529,20 @@ class SimpleHttpClient(object):
if response.code > 299:
logger.warn("Got %d when downloading %s" % (response.code, url))
- raise SynapseError(
- 502,
- "Got error %d" % (response.code,),
- Codes.UNKNOWN,
- )
+ raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
# TODO: if our Content-Type is HTML or something, just read the first
# N bytes into RAM rather than saving it all to disk only to read it
# straight back in again
try:
- length = yield make_deferred_yieldable(_readBodyToFile(
- response, output_stream, max_size,
- ))
+ length = yield make_deferred_yieldable(
+ _readBodyToFile(response, output_stream, max_size)
+ )
except Exception as e:
logger.exception("Failed to download body")
raise SynapseError(
- 502,
- ("Failed to download remote body: %s" % e),
- Codes.UNKNOWN,
+ 502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN
)
defer.returnValue(
@@ -387,13 +551,14 @@ class SimpleHttpClient(object):
resp_headers,
response.request.absoluteURI.decode('ascii'),
response.code,
- ),
+ )
)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
+
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
self.stream = stream
@@ -405,11 +570,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- ))
+ self.deferred.errback(
+ SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (self.max_size,),
+ Codes.TOO_LARGE,
+ )
+ )
self.deferred = defer.Deferred()
self.transport.loseConnection()
@@ -427,6 +594,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
+
def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
@@ -449,10 +617,12 @@ class CaptchaServerHttpClient(SimpleHttpClient):
"POST",
url,
data=query_bytes,
- headers=Headers({
- b"Content-Type": [b"application/x-www-form-urlencoded"],
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(
+ {
+ b"Content-Type": [b"application/x-www-form-urlencoded"],
+ b"User-Agent": [self.user_agent],
+ }
+ ),
)
try:
@@ -463,57 +633,6 @@ class CaptchaServerHttpClient(SimpleHttpClient):
defer.returnValue(e.response)
-class SpiderEndpointFactory(object):
- def __init__(self, hs):
- self.blacklist = hs.config.url_preview_ip_range_blacklist
- self.whitelist = hs.config.url_preview_ip_range_whitelist
- self.policyForHTTPS = hs.get_http_client_context_factory()
-
- def endpointForURI(self, uri):
- logger.info("Getting endpoint for %s", uri.toBytes())
-
- if uri.scheme == b"http":
- endpoint_factory = HostnameEndpoint
- elif uri.scheme == b"https":
- tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
-
- def endpoint_factory(reactor, host, port, **kw):
- return wrapClientTLS(
- tlsCreator,
- HostnameEndpoint(reactor, host, port, **kw))
- else:
- logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
- return None
- return SpiderEndpoint(
- reactor, uri.host, uri.port, self.blacklist, self.whitelist,
- endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15),
- )
-
-
-class SpiderHttpClient(SimpleHttpClient):
- """
- Separate HTTP client for spidering arbitrary URLs.
- Special in that it follows retries and has a UA that looks
- like a browser.
-
- used by the preview_url endpoint in the content repo.
- """
- def __init__(self, hs):
- SimpleHttpClient.__init__(self, hs)
- # clobber the base class's agent and UA:
- self.agent = ContentDecoderAgent(
- BrowserLikeRedirectAgent(
- Agent.usingEndpointFactory(
- reactor,
- SpiderEndpointFactory(hs)
- )
- ), [(b'gzip', GzipDecoder)]
- )
- # We could look like Chrome:
- # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
- # Chrome Safari" % hs.version_string)
-
-
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 91025037a3..cd79ebab62 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,30 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
import logging
-import random
import re
-import time
-
-from twisted.internet import defer
-from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.error import ConnectError
-from twisted.names import client, dns
-from twisted.names.error import DNSNameError, DomainError
logger = logging.getLogger(__name__)
-SERVER_CACHE = {}
-
-# our record of an individual server which can be tried to reach a destination.
-#
-# "host" is the hostname acquired from the SRV record. Except when there's
-# no SRV record, in which case it is the original hostname.
-_Server = collections.namedtuple(
- "_Server", "priority weight host port expires"
-)
-
def parse_server_name(server_name):
"""Split a server name into host/port parts.
@@ -100,299 +81,3 @@ def parse_and_validate_server_name(server_name):
))
return host, port
-
-
-def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
- timeout=None):
- """Construct an endpoint for the given matrix destination.
-
- Args:
- reactor: Twisted reactor.
- destination (unicode): The name of the server to connect to.
- tls_client_options_factory
- (synapse.crypto.context_factory.ClientTLSOptionsFactory):
- Factory which generates TLS options for client connections.
- timeout (int): connection timeout in seconds
- """
-
- domain, port = parse_server_name(destination)
-
- endpoint_kw_args = {}
-
- if timeout is not None:
- endpoint_kw_args.update(timeout=timeout)
-
- if tls_client_options_factory is None:
- transport_endpoint = HostnameEndpoint
- default_port = 8008
- else:
- # the SNI string should be the same as the Host header, minus the port.
- # as per https://github.com/matrix-org/synapse/issues/2525#issuecomment-336896777,
- # the Host header and SNI should therefore be the server_name of the remote
- # server.
- tls_options = tls_client_options_factory.get_options(domain)
-
- def transport_endpoint(reactor, host, port, timeout):
- return wrapClientTLS(
- tls_options,
- HostnameEndpoint(reactor, host, port, timeout=timeout),
- )
- default_port = 8448
-
- if port is None:
- return _WrappingEndpointFac(SRVClientEndpoint(
- reactor, "matrix", domain, protocol="tcp",
- default_port=default_port, endpoint=transport_endpoint,
- endpoint_kw_args=endpoint_kw_args
- ), reactor)
- else:
- return _WrappingEndpointFac(transport_endpoint(
- reactor, domain, port, **endpoint_kw_args
- ), reactor)
-
-
-class _WrappingEndpointFac(object):
- def __init__(self, endpoint_fac, reactor):
- self.endpoint_fac = endpoint_fac
- self.reactor = reactor
-
- @defer.inlineCallbacks
- def connect(self, protocolFactory):
- conn = yield self.endpoint_fac.connect(protocolFactory)
- conn = _WrappedConnection(conn, self.reactor)
- defer.returnValue(conn)
-
-
-class _WrappedConnection(object):
- """Wraps a connection and calls abort on it if it hasn't seen any action
- for 2.5-3 minutes.
- """
- __slots__ = ["conn", "last_request"]
-
- def __init__(self, conn, reactor):
- object.__setattr__(self, "conn", conn)
- object.__setattr__(self, "last_request", time.time())
- self._reactor = reactor
-
- def __getattr__(self, name):
- return getattr(self.conn, name)
-
- def __setattr__(self, name, value):
- setattr(self.conn, name, value)
-
- def _time_things_out_maybe(self):
- # We use a slightly shorter timeout here just in case the callLater is
- # triggered early. Paranoia ftw.
- # TODO: Cancel the previous callLater rather than comparing time.time()?
- if time.time() - self.last_request >= 2.5 * 60:
- self.abort()
- # Abort the underlying TLS connection. The abort() method calls
- # loseConnection() on the TLS connection which tries to
- # shutdown the connection cleanly. We call abortConnection()
- # since that will promptly close the TLS connection.
- #
- # In Twisted >18.4; the TLS connection will be None if it has closed
- # which will make abortConnection() throw. Check that the TLS connection
- # is not None before trying to close it.
- if self.transport.getHandle() is not None:
- self.transport.abortConnection()
-
- def request(self, request):
- self.last_request = time.time()
-
- # Time this connection out if we haven't send a request in the last
- # N minutes
- # TODO: Cancel the previous callLater?
- self._reactor.callLater(3 * 60, self._time_things_out_maybe)
-
- d = self.conn.request(request)
-
- def update_request_time(res):
- self.last_request = time.time()
- # TODO: Cancel the previous callLater?
- self._reactor.callLater(3 * 60, self._time_things_out_maybe)
- return res
-
- d.addCallback(update_request_time)
-
- return d
-
-
-class SpiderEndpoint(object):
- """An endpoint which refuses to connect to blacklisted IP addresses
- Implements twisted.internet.interfaces.IStreamClientEndpoint.
- """
- def __init__(self, reactor, host, port, blacklist, whitelist,
- endpoint=HostnameEndpoint, endpoint_kw_args={}):
- self.reactor = reactor
- self.host = host
- self.port = port
- self.blacklist = blacklist
- self.whitelist = whitelist
- self.endpoint = endpoint
- self.endpoint_kw_args = endpoint_kw_args
-
- @defer.inlineCallbacks
- def connect(self, protocolFactory):
- address = yield self.reactor.resolve(self.host)
-
- from netaddr import IPAddress
- ip_address = IPAddress(address)
-
- if ip_address in self.blacklist:
- if self.whitelist is None or ip_address not in self.whitelist:
- raise ConnectError(
- "Refusing to spider blacklisted IP address %s" % address
- )
-
- logger.info("Connecting to %s:%s", address, self.port)
- endpoint = self.endpoint(
- self.reactor, address, self.port, **self.endpoint_kw_args
- )
- connection = yield endpoint.connect(protocolFactory)
- defer.returnValue(connection)
-
-
-class SRVClientEndpoint(object):
- """An endpoint which looks up SRV records for a service.
- Cycles through the list of servers starting with each call to connect
- picking the next server.
- Implements twisted.internet.interfaces.IStreamClientEndpoint.
- """
-
- def __init__(self, reactor, service, domain, protocol="tcp",
- default_port=None, endpoint=HostnameEndpoint,
- endpoint_kw_args={}):
- self.reactor = reactor
- self.service_name = "_%s._%s.%s" % (service, protocol, domain)
-
- if default_port is not None:
- self.default_server = _Server(
- host=domain,
- port=default_port,
- priority=0,
- weight=0,
- expires=0,
- )
- else:
- self.default_server = None
-
- self.endpoint = endpoint
- self.endpoint_kw_args = endpoint_kw_args
-
- self.servers = None
- self.used_servers = None
-
- @defer.inlineCallbacks
- def fetch_servers(self):
- self.used_servers = []
- self.servers = yield resolve_service(self.service_name)
-
- def pick_server(self):
- if not self.servers:
- if self.used_servers:
- self.servers = self.used_servers
- self.used_servers = []
- self.servers.sort()
- elif self.default_server:
- return self.default_server
- else:
- raise ConnectError(
- "No server available for %s" % self.service_name
- )
-
- # look for all servers with the same priority
- min_priority = self.servers[0].priority
- weight_indexes = list(
- (index, server.weight + 1)
- for index, server in enumerate(self.servers)
- if server.priority == min_priority
- )
-
- total_weight = sum(weight for index, weight in weight_indexes)
- target_weight = random.randint(0, total_weight)
- for index, weight in weight_indexes:
- target_weight -= weight
- if target_weight <= 0:
- server = self.servers[index]
- # XXX: this looks totally dubious:
- #
- # (a) we never reuse a server until we have been through
- # all of the servers at the same priority, so if the
- # weights are A: 100, B:1, we always do ABABAB instead of
- # AAAA...AAAB (approximately).
- #
- # (b) After using all the servers at the lowest priority,
- # we move onto the next priority. We should only use the
- # second priority if servers at the top priority are
- # unreachable.
- #
- del self.servers[index]
- self.used_servers.append(server)
- return server
-
- @defer.inlineCallbacks
- def connect(self, protocolFactory):
- if self.servers is None:
- yield self.fetch_servers()
- server = self.pick_server()
- logger.info("Connecting to %s:%s", server.host, server.port)
- endpoint = self.endpoint(
- self.reactor, server.host, server.port, **self.endpoint_kw_args
- )
- connection = yield endpoint.connect(protocolFactory)
- defer.returnValue(connection)
-
-
-@defer.inlineCallbacks
-def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
- cache_entry = cache.get(service_name, None)
- if cache_entry:
- if all(s.expires > int(clock.time()) for s in cache_entry):
- servers = list(cache_entry)
- defer.returnValue(servers)
-
- servers = []
-
- try:
- try:
- answers, _, _ = yield dns_client.lookupService(service_name)
- except DNSNameError:
- defer.returnValue([])
-
- if (len(answers) == 1
- and answers[0].type == dns.SRV
- and answers[0].payload
- and answers[0].payload.target == dns.Name(b'.')):
- raise ConnectError("Service %s unavailable" % service_name)
-
- for answer in answers:
- if answer.type != dns.SRV or not answer.payload:
- continue
-
- payload = answer.payload
-
- servers.append(_Server(
- host=str(payload.target),
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight),
- expires=int(clock.time()) + answer.ttl,
- ))
-
- servers.sort()
- cache[service_name] = list(servers)
- except DomainError as e:
- # We failed to resolve the name (other than a NameError)
- # Try something in the cache, else rereaise
- cache_entry = cache.get(service_name, None)
- if cache_entry:
- logger.warn(
- "Failed to resolve %r, falling back to cache. %r",
- service_name, e
- )
- servers = list(cache_entry)
- else:
- raise e
-
- defer.returnValue(servers)
diff --git a/synapse/http/federation/__init__.py b/synapse/http/federation/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/synapse/http/federation/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
new file mode 100644
index 0000000000..384d8a37a2
--- /dev/null
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -0,0 +1,452 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import random
+import time
+
+import attr
+from netaddr import IPAddress
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.interfaces import IStreamClientEndpoint
+from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody
+from twisted.web.http import stringToDatetime
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IAgent
+
+from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.util import Clock
+from synapse.util.caches.ttlcache import TTLCache
+from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.metrics import Measure
+
+# period to cache .well-known results for by default
+WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
+
+# jitter to add to the .well-known default cache ttl
+WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
+
+# period to cache failure to fetch .well-known for
+WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
+
+# cap for .well-known cache period
+WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
+
+logger = logging.getLogger(__name__)
+well_known_cache = TTLCache('well-known')
+
+
+@implementer(IAgent)
+class MatrixFederationAgent(object):
+ """An Agent-like thing which provides a `request` method which will look up a matrix
+ server and send an HTTP request to it.
+
+ Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
+
+ Args:
+ reactor (IReactor): twisted reactor to use for underlying requests
+
+ tls_client_options_factory (ClientTLSOptionsFactory|None):
+ factory to use for fetching client tls options, or none to disable TLS.
+
+ _well_known_tls_policy (IPolicyForHTTPS|None):
+ TLS policy to use for fetching .well-known files. None to use a default
+ (browser-like) implementation.
+
+ srv_resolver (SrvResolver|None):
+ SRVResolver impl to use for looking up SRV records. None to use a default
+ implementation.
+ """
+
+ def __init__(
+ self, reactor, tls_client_options_factory,
+ _well_known_tls_policy=None,
+ _srv_resolver=None,
+ _well_known_cache=well_known_cache,
+ ):
+ self._reactor = reactor
+ self._clock = Clock(reactor)
+
+ self._tls_client_options_factory = tls_client_options_factory
+ if _srv_resolver is None:
+ _srv_resolver = SrvResolver()
+ self._srv_resolver = _srv_resolver
+
+ self._pool = HTTPConnectionPool(reactor)
+ self._pool.retryAutomatically = False
+ self._pool.maxPersistentPerHost = 5
+ self._pool.cachedConnectionTimeout = 2 * 60
+
+ agent_args = {}
+ if _well_known_tls_policy is not None:
+ # the param is called 'contextFactory', but actually passing a
+ # contextfactory is deprecated, and it expects an IPolicyForHTTPS.
+ agent_args['contextFactory'] = _well_known_tls_policy
+ _well_known_agent = RedirectAgent(
+ Agent(self._reactor, pool=self._pool, **agent_args),
+ )
+ self._well_known_agent = _well_known_agent
+
+ # our cache of .well-known lookup results, mapping from server name
+ # to delegated name. The values can be:
+ # `bytes`: a valid server-name
+ # `None`: there is no (valid) .well-known here
+ self._well_known_cache = _well_known_cache
+
+ @defer.inlineCallbacks
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ """
+ Args:
+ method (bytes): HTTP method: GET/POST/etc
+
+ uri (bytes): Absolute URI to be retrieved
+
+ headers (twisted.web.http_headers.Headers|None):
+ HTTP headers to send with the request, or None to
+ send no extra headers.
+
+ bodyProducer (twisted.web.iweb.IBodyProducer|None):
+ An object which can generate bytes to make up the
+ body of this request (for example, the properly encoded contents of
+ a file for a file upload). Or None if the request is to have
+ no body.
+
+ Returns:
+ Deferred[twisted.web.iweb.IResponse]:
+ fires when the header of the response has been received (regardless of the
+ response status code). Fails if there is any problem which prevents that
+ response from being received (including problems that prevent the request
+ from being sent).
+ """
+ parsed_uri = URI.fromBytes(uri, defaultPort=-1)
+ res = yield self._route_matrix_uri(parsed_uri)
+
+ # set up the TLS connection params
+ #
+ # XXX disabling TLS is really only supported here for the benefit of the
+ # unit tests. We should make the UTs cope with TLS rather than having to make
+ # the code support the unit tests.
+ if self._tls_client_options_factory is None:
+ tls_options = None
+ else:
+ tls_options = self._tls_client_options_factory.get_options(
+ res.tls_server_name.decode("ascii")
+ )
+
+ # make sure that the Host header is set correctly
+ if headers is None:
+ headers = Headers()
+ else:
+ headers = headers.copy()
+
+ if not headers.hasHeader(b'host'):
+ headers.addRawHeader(b'host', res.host_header)
+
+ class EndpointFactory(object):
+ @staticmethod
+ def endpointForURI(_uri):
+ ep = LoggingHostnameEndpoint(
+ self._reactor, res.target_host, res.target_port,
+ )
+ if tls_options is not None:
+ ep = wrapClientTLS(tls_options, ep)
+ return ep
+
+ agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
+ res = yield make_deferred_yieldable(
+ agent.request(method, uri, headers, bodyProducer)
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
+ """Helper for `request`: determine the routing for a Matrix URI
+
+ Args:
+ parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
+ parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
+ if there is no explicit port given.
+
+ lookup_well_known (bool): True if we should look up the .well-known file if
+ there is no SRV record.
+
+ Returns:
+ Deferred[_RoutingResult]
+ """
+ # check for an IP literal
+ try:
+ ip_address = IPAddress(parsed_uri.host.decode("ascii"))
+ except Exception:
+ # not an IP address
+ ip_address = None
+
+ if ip_address:
+ port = parsed_uri.port
+ if port == -1:
+ port = 8448
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=port,
+ ))
+
+ if parsed_uri.port != -1:
+ # there is an explicit port
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=parsed_uri.port,
+ ))
+
+ if lookup_well_known:
+ # try a .well-known lookup
+ well_known_server = yield self._get_well_known(parsed_uri.host)
+
+ if well_known_server:
+ # if we found a .well-known, start again, but don't do another
+ # .well-known lookup.
+
+ # parse the server name in the .well-known response into host/port.
+ # (This code is lifted from twisted.web.client.URI.fromBytes).
+ if b':' in well_known_server:
+ well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
+ try:
+ well_known_port = int(well_known_port)
+ except ValueError:
+ # the part after the colon could not be parsed as an int
+ # - we assume it is an IPv6 literal with no port (the closing
+ # ']' stops it being parsed as an int)
+ well_known_host, well_known_port = well_known_server, -1
+ else:
+ well_known_host, well_known_port = well_known_server, -1
+
+ new_uri = URI(
+ scheme=parsed_uri.scheme,
+ netloc=well_known_server,
+ host=well_known_host,
+ port=well_known_port,
+ path=parsed_uri.path,
+ params=parsed_uri.params,
+ query=parsed_uri.query,
+ fragment=parsed_uri.fragment,
+ )
+
+ res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
+ defer.returnValue(res)
+
+ # try a SRV lookup
+ service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
+ server_list = yield self._srv_resolver.resolve_service(service_name)
+
+ if not server_list:
+ target_host = parsed_uri.host
+ port = 8448
+ logger.debug(
+ "No SRV record for %s, using %s:%i",
+ parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
+ )
+ else:
+ target_host, port = pick_server_from_list(server_list)
+ logger.debug(
+ "Picked %s:%i from SRV records for %s",
+ target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
+ )
+
+ defer.returnValue(_RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=target_host,
+ target_port=port,
+ ))
+
+ @defer.inlineCallbacks
+ def _get_well_known(self, server_name):
+ """Attempt to fetch and parse a .well-known file for the given server
+
+ Args:
+ server_name (bytes): name of the server, from the requested url
+
+ Returns:
+ Deferred[bytes|None]: either the new server name, from the .well-known, or
+ None if there was no .well-known file.
+ """
+ try:
+ result = self._well_known_cache[server_name]
+ except KeyError:
+ # TODO: should we linearise so that we don't end up doing two .well-known
+ # requests for the same server in parallel?
+ with Measure(self._clock, "get_well_known"):
+ result, cache_period = yield self._do_get_well_known(server_name)
+
+ if cache_period > 0:
+ self._well_known_cache.set(server_name, result, cache_period)
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _do_get_well_known(self, server_name):
+ """Actually fetch and parse a .well-known, without checking the cache
+
+ Args:
+ server_name (bytes): name of the server, from the requested url
+
+ Returns:
+ Deferred[Tuple[bytes|None|object],int]:
+ result, cache period, where result is one of:
+ - the new server name from the .well-known (as a `bytes`)
+ - None if there was no .well-known file.
+ - INVALID_WELL_KNOWN if the .well-known was invalid
+ """
+ uri = b"https://%s/.well-known/matrix/server" % (server_name, )
+ uri_str = uri.decode("ascii")
+ logger.info("Fetching %s", uri_str)
+ try:
+ response = yield make_deferred_yieldable(
+ self._well_known_agent.request(b"GET", uri),
+ )
+ body = yield make_deferred_yieldable(readBody(response))
+ if response.code != 200:
+ raise Exception("Non-200 response %s" % (response.code, ))
+
+ parsed_body = json.loads(body.decode('utf-8'))
+ logger.info("Response from .well-known: %s", parsed_body)
+ if not isinstance(parsed_body, dict):
+ raise Exception("not a dict")
+ if "m.server" not in parsed_body:
+ raise Exception("Missing key 'm.server'")
+ except Exception as e:
+ logger.info("Error fetching %s: %s", uri_str, e)
+
+ # add some randomness to the TTL to avoid a stampeding herd every hour
+ # after startup
+ cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
+ cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
+ defer.returnValue((None, cache_period))
+
+ result = parsed_body["m.server"].encode("ascii")
+
+ cache_period = _cache_period_from_headers(
+ response.headers,
+ time_now=self._reactor.seconds,
+ )
+ if cache_period is None:
+ cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
+ # add some randomness to the TTL to avoid a stampeding herd every 24 hours
+ # after startup
+ cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
+ else:
+ cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
+
+ defer.returnValue((result, cache_period))
+
+
+@implementer(IStreamClientEndpoint)
+class LoggingHostnameEndpoint(object):
+ """A wrapper for HostnameEndpint which logs when it connects"""
+ def __init__(self, reactor, host, port, *args, **kwargs):
+ self.host = host
+ self.port = port
+ self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
+
+ def connect(self, protocol_factory):
+ logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
+ return self.ep.connect(protocol_factory)
+
+
+def _cache_period_from_headers(headers, time_now=time.time):
+ cache_controls = _parse_cache_control(headers)
+
+ if b'no-store' in cache_controls:
+ return 0
+
+ if b'max-age' in cache_controls:
+ try:
+ max_age = int(cache_controls[b'max-age'])
+ return max_age
+ except ValueError:
+ pass
+
+ expires = headers.getRawHeaders(b'expires')
+ if expires is not None:
+ try:
+ expires_date = stringToDatetime(expires[-1])
+ return expires_date - time_now()
+ except ValueError:
+ # RFC7234 says 'A cache recipient MUST interpret invalid date formats,
+ # especially the value "0", as representing a time in the past (i.e.,
+ # "already expired").
+ return 0
+
+ return None
+
+
+def _parse_cache_control(headers):
+ cache_controls = {}
+ for hdr in headers.getRawHeaders(b'cache-control', []):
+ for directive in hdr.split(b','):
+ splits = [x.strip() for x in directive.split(b'=', 1)]
+ k = splits[0].lower()
+ v = splits[1] if len(splits) > 1 else None
+ cache_controls[k] = v
+ return cache_controls
+
+
+@attr.s
+class _RoutingResult(object):
+ """The result returned by `_route_matrix_uri`.
+
+ Contains the parameters needed to direct a federation connection to a particular
+ server.
+
+ Where a SRV record points to several servers, this object contains a single server
+ chosen from the list.
+ """
+
+ host_header = attr.ib()
+ """
+ The value we should assign to the Host header (host:port from the matrix
+ URI, or .well-known).
+
+ :type: bytes
+ """
+
+ tls_server_name = attr.ib()
+ """
+ The server name we should set in the SNI (typically host, without port, from the
+ matrix URI or .well-known)
+
+ :type: bytes
+ """
+
+ target_host = attr.ib()
+ """
+ The hostname (or IP literal) we should route the TCP connection to (the target of the
+ SRV record, or the hostname from the URL/.well-known)
+
+ :type: bytes
+ """
+
+ target_port = attr.ib()
+ """
+ The port we should route the TCP connection to (the target of the SRV record, or
+ the port from the URL/.well-known, or 8448)
+
+ :type: int
+ """
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
new file mode 100644
index 0000000000..71830c549d
--- /dev/null
+++ b/synapse/http/federation/srv_resolver.py
@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import random
+import time
+
+import attr
+
+from twisted.internet import defer
+from twisted.internet.error import ConnectError
+from twisted.names import client, dns
+from twisted.names.error import DNSNameError, DomainError
+
+from synapse.util.logcontext import make_deferred_yieldable
+
+logger = logging.getLogger(__name__)
+
+SERVER_CACHE = {}
+
+
+@attr.s
+class Server(object):
+ """
+ Our record of an individual server which can be tried to reach a destination.
+
+ Attributes:
+ host (bytes): target hostname
+ port (int):
+ priority (int):
+ weight (int):
+ expires (int): when the cache should expire this record - in *seconds* since
+ the epoch
+ """
+ host = attr.ib()
+ port = attr.ib()
+ priority = attr.ib(default=0)
+ weight = attr.ib(default=0)
+ expires = attr.ib(default=0)
+
+
+def pick_server_from_list(server_list):
+ """Randomly choose a server from the server list
+
+ Args:
+ server_list (list[Server]): list of candidate servers
+
+ Returns:
+ Tuple[bytes, int]: (host, port) pair for the chosen server
+ """
+ if not server_list:
+ raise RuntimeError("pick_server_from_list called with empty list")
+
+ # TODO: currently we only use the lowest-priority servers. We should maintain a
+ # cache of servers known to be "down" and filter them out
+
+ min_priority = min(s.priority for s in server_list)
+ eligible_servers = list(s for s in server_list if s.priority == min_priority)
+ total_weight = sum(s.weight for s in eligible_servers)
+ target_weight = random.randint(0, total_weight)
+
+ for s in eligible_servers:
+ target_weight -= s.weight
+
+ if target_weight <= 0:
+ return s.host, s.port
+
+ # this should be impossible.
+ raise RuntimeError(
+ "pick_server_from_list got to end of eligible server list.",
+ )
+
+
+class SrvResolver(object):
+ """Interface to the dns client to do SRV lookups, with result caching.
+
+ The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
+ but the cache never gets populated), so we add our own caching layer here.
+
+ Args:
+ dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
+ cache (dict): cache object
+ get_time (callable): clock implementation. Should return seconds since the epoch
+ """
+ def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
+ self._dns_client = dns_client
+ self._cache = cache
+ self._get_time = get_time
+
+ @defer.inlineCallbacks
+ def resolve_service(self, service_name):
+ """Look up a SRV record
+
+ Args:
+ service_name (bytes): record to look up
+
+ Returns:
+ Deferred[list[Server]]:
+ a list of the SRV records, or an empty list if none found
+ """
+ now = int(self._get_time())
+
+ if not isinstance(service_name, bytes):
+ raise TypeError("%r is not a byte string" % (service_name,))
+
+ cache_entry = self._cache.get(service_name, None)
+ if cache_entry:
+ if all(s.expires > now for s in cache_entry):
+ servers = list(cache_entry)
+ defer.returnValue(servers)
+
+ try:
+ answers, _, _ = yield make_deferred_yieldable(
+ self._dns_client.lookupService(service_name),
+ )
+ except DNSNameError:
+ # TODO: cache this. We can get the SOA out of the exception, and use
+ # the negative-TTL value.
+ defer.returnValue([])
+ except DomainError as e:
+ # We failed to resolve the name (other than a NameError)
+ # Try something in the cache, else rereaise
+ cache_entry = self._cache.get(service_name, None)
+ if cache_entry:
+ logger.warn(
+ "Failed to resolve %r, falling back to cache. %r",
+ service_name, e
+ )
+ defer.returnValue(list(cache_entry))
+ else:
+ raise e
+
+ if (len(answers) == 1
+ and answers[0].type == dns.SRV
+ and answers[0].payload
+ and answers[0].payload.target == dns.Name(b'.')):
+ raise ConnectError("Service %s unavailable" % service_name)
+
+ servers = []
+
+ for answer in answers:
+ if answer.type != dns.SRV or not answer.payload:
+ continue
+
+ payload = answer.payload
+
+ servers.append(Server(
+ host=payload.target.name,
+ port=payload.port,
+ priority=payload.priority,
+ weight=payload.weight,
+ expires=now + answer.ttl,
+ ))
+
+ self._cache[service_name] = list(servers)
+ defer.returnValue(servers)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 24b6110c20..1682c9af13 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -19,7 +19,7 @@ import random
import sys
from io import BytesIO
-from six import PY3, string_types
+from six import PY3, raise_from, string_types
from six.moves import urllib
import attr
@@ -32,7 +32,6 @@ from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
-from twisted.web.client import Agent, FileBodyProducer, HTTPConnectionPool
from twisted.web.http_headers import Headers
import synapse.metrics
@@ -41,9 +40,11 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
+ RequestSendFailed,
SynapseError,
)
-from synapse.http.endpoint import matrix_federation_endpoint
+from synapse.http import QuieterFileBodyProducer
+from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.util.async_helpers import timeout_deferred
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.metrics import Measure
@@ -65,20 +66,6 @@ else:
MAXINT = sys.maxint
-class MatrixFederationEndpointFactory(object):
- def __init__(self, hs):
- self.reactor = hs.get_reactor()
- self.tls_client_options_factory = hs.tls_client_options_factory
-
- def endpointForURI(self, uri):
- destination = uri.netloc.decode('ascii')
-
- return matrix_federation_endpoint(
- self.reactor, destination, timeout=10,
- tls_client_options_factory=self.tls_client_options_factory
- )
-
-
_next_id = 1
@@ -181,17 +168,15 @@ class MatrixFederationHttpClient(object):
requests.
"""
- def __init__(self, hs):
+ def __init__(self, hs, tls_client_options_factory):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
reactor = hs.get_reactor()
- pool = HTTPConnectionPool(reactor)
- pool.retryAutomatically = False
- pool.maxPersistentPerHost = 5
- pool.cachedConnectionTimeout = 2 * 60
- self.agent = Agent.usingEndpointFactory(
- reactor, MatrixFederationEndpointFactory(hs), pool=pool
+
+ self.agent = MatrixFederationAgent(
+ hs.get_reactor(),
+ tls_client_options_factory,
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
@@ -228,19 +213,18 @@ class MatrixFederationHttpClient(object):
backoff_on_404 (bool): Back off if we get a 404
Returns:
- Deferred: resolves with the http response object on success.
-
- Fails with ``HttpResponseException``: if we get an HTTP response
- code >= 300.
-
- Fails with ``NotRetryingDestination`` if we are not yet ready
- to retry this server.
-
- Fails with ``FederationDeniedError`` if this destination
- is not on our federation whitelist
-
- (May also fail with plenty of other Exceptions for things like DNS
- failures, connection failures, SSL failures.)
+ Deferred[twisted.web.client.Response]: resolves with the HTTP
+ response object on success.
+
+ Raises:
+ HttpResponseException: If we get an HTTP response code >= 300
+ (except 429).
+ NotRetryingDestination: If we are not yet ready to retry this
+ server.
+ FederationDeniedError: If this destination is not on our
+ federation whitelist
+ RequestSendFailed: If there were problems connecting to the
+ remote, due to e.g. DNS failures, connection timeouts etc.
"""
if timeout:
_sec_timeout = timeout / 1000
@@ -271,7 +255,6 @@ class MatrixFederationHttpClient(object):
headers_dict = {
b"User-Agent": [self.version_string_bytes],
- b"Host": [destination_bytes],
}
with limiter:
@@ -298,60 +281,111 @@ class MatrixFederationHttpClient(object):
json = request.get_json()
if json:
headers_dict[b"Content-Type"] = [b"application/json"]
- self.sign_request(
+ auth_headers = self.build_auth_headers(
destination_bytes, method_bytes, url_to_sign_bytes,
- headers_dict, json,
+ json,
)
data = encode_canonical_json(json)
- producer = FileBodyProducer(
+ producer = QuieterFileBodyProducer(
BytesIO(data),
cooperator=self._cooperator,
)
else:
producer = None
- self.sign_request(
+ auth_headers = self.build_auth_headers(
destination_bytes, method_bytes, url_to_sign_bytes,
- headers_dict,
)
+ headers_dict[b"Authorization"] = auth_headers
+
logger.info(
- "{%s} [%s] Sending request: %s %s",
+ "{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id, request.destination, request.method,
- url_str,
+ url_str, _sec_timeout,
)
- # we don't want all the fancy cookie and redirect handling that
- # treq.request gives: just use the raw Agent.
- request_deferred = self.agent.request(
- method_bytes,
- url_bytes,
- headers=Headers(headers_dict),
- bodyProducer=producer,
- )
+ try:
+ with Measure(self.clock, "outbound_request"):
+ # we don't want all the fancy cookie and redirect handling
+ # that treq.request gives: just use the raw Agent.
+ request_deferred = self.agent.request(
+ method_bytes,
+ url_bytes,
+ headers=Headers(headers_dict),
+ bodyProducer=producer,
+ )
+
+ request_deferred = timeout_deferred(
+ request_deferred,
+ timeout=_sec_timeout,
+ reactor=self.hs.get_reactor(),
+ )
+
+ response = yield request_deferred
+ except DNSLookupError as e:
+ raise_from(RequestSendFailed(e, can_retry=retry_on_dns_fail), e)
+ except Exception as e:
+ logger.info("Failed to send request: %s", e)
+ raise_from(RequestSendFailed(e, can_retry=True), e)
- request_deferred = timeout_deferred(
- request_deferred,
- timeout=_sec_timeout,
- reactor=self.hs.get_reactor(),
+ logger.info(
+ "{%s} [%s] Got response headers: %d %s",
+ request.txn_id,
+ request.destination,
+ response.code,
+ response.phrase.decode('ascii', errors='replace'),
)
- with Measure(self.clock, "outbound_request"):
- response = yield make_deferred_yieldable(
- request_deferred,
+ if 200 <= response.code < 300:
+ pass
+ else:
+ # :'(
+ # Update transactions table?
+ d = treq.content(response)
+ d = timeout_deferred(
+ d,
+ timeout=_sec_timeout,
+ reactor=self.hs.get_reactor(),
+ )
+
+ try:
+ body = yield make_deferred_yieldable(d)
+ except Exception as e:
+ # Eh, we're already going to raise an exception so lets
+ # ignore if this fails.
+ logger.warn(
+ "{%s} [%s] Failed to get error response: %s %s: %s",
+ request.txn_id,
+ request.destination,
+ request.method,
+ url_str,
+ _flatten_response_never_received(e),
+ )
+ body = None
+
+ e = HttpResponseException(
+ response.code, response.phrase, body
)
+ # Retry if the error is a 429 (Too Many Requests),
+ # otherwise just raise a standard HttpResponseException
+ if response.code == 429:
+ raise_from(RequestSendFailed(e, can_retry=True), e)
+ else:
+ raise e
+
break
- except Exception as e:
+ except RequestSendFailed as e:
logger.warn(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
request.method,
url_str,
- _flatten_response_never_received(e),
+ _flatten_response_never_received(e.inner_exception),
)
- if not retry_on_dns_fail and isinstance(e, DNSLookupError):
+ if not e.can_retry:
raise
if retries_left and not timeout:
@@ -376,50 +410,36 @@ class MatrixFederationHttpClient(object):
else:
raise
- logger.info(
- "{%s} [%s] Got response headers: %d %s",
- request.txn_id,
- request.destination,
- response.code,
- response.phrase.decode('ascii', errors='replace'),
- )
-
- if 200 <= response.code < 300:
- pass
- else:
- # :'(
- # Update transactions table?
- d = treq.content(response)
- d = timeout_deferred(
- d,
- timeout=_sec_timeout,
- reactor=self.hs.get_reactor(),
- )
- body = yield make_deferred_yieldable(d)
- raise HttpResponseException(
- response.code, response.phrase, body
- )
+ except Exception as e:
+ logger.warn(
+ "{%s} [%s] Request failed: %s %s: %s",
+ request.txn_id,
+ request.destination,
+ request.method,
+ url_str,
+ _flatten_response_never_received(e),
+ )
+ raise
defer.returnValue(response)
- def sign_request(self, destination, method, url_bytes, headers_dict,
- content=None, destination_is=None):
+ def build_auth_headers(
+ self, destination, method, url_bytes, content=None, destination_is=None,
+ ):
"""
- Signs a request by adding an Authorization header to headers_dict
+ Builds the Authorization headers for a federation request
Args:
destination (bytes|None): The desination home server of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
method (bytes): The HTTP method of the request
url_bytes (bytes): The URI path of the request
- headers_dict (dict[bytes, list[bytes]]): Dictionary of request headers to
- append to
content (object): The body of the request
destination_is (bytes): As 'destination', but if the destination is an
identity server
Returns:
- None
+ list[bytes]: a list of headers to be added as "Authorization:" headers
"""
request = {
"method": method,
@@ -446,8 +466,7 @@ class MatrixFederationHttpClient(object):
self.server_name, key, sig,
)).encode('ascii')
)
-
- headers_dict[b"Authorization"] = auth_headers
+ return auth_headers
@defer.inlineCallbacks
def put_json(self, destination, path, args={}, data={},
@@ -477,17 +496,18 @@ class MatrixFederationHttpClient(object):
requests)
Returns:
- Deferred: Succeeds when we get a 2xx HTTP response. The result
- will be the decoded JSON body.
-
- Fails with ``HttpResponseException`` if we get an HTTP response
- code >= 300.
-
- Fails with ``NotRetryingDestination`` if we are not yet ready
- to retry this server.
-
- Fails with ``FederationDeniedError`` if this destination
- is not on our federation whitelist
+ Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
+ result will be the decoded JSON body.
+
+ Raises:
+ HttpResponseException: If we get an HTTP response code >= 300
+ (except 429).
+ NotRetryingDestination: If we are not yet ready to retry this
+ server.
+ FederationDeniedError: If this destination is not on our
+ federation whitelist
+ RequestSendFailed: If there were problems connecting to the
+ remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
@@ -531,17 +551,18 @@ class MatrixFederationHttpClient(object):
try the request anyway.
args (dict): query params
Returns:
- Deferred: Succeeds when we get a 2xx HTTP response. The result
- will be the decoded JSON body.
-
- Fails with ``HttpResponseException`` if we get an HTTP response
- code >= 300.
-
- Fails with ``NotRetryingDestination`` if we are not yet ready
- to retry this server.
-
- Fails with ``FederationDeniedError`` if this destination
- is not on our federation whitelist
+ Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
+ result will be the decoded JSON body.
+
+ Raises:
+ HttpResponseException: If we get an HTTP response code >= 300
+ (except 429).
+ NotRetryingDestination: If we are not yet ready to retry this
+ server.
+ FederationDeniedError: If this destination is not on our
+ federation whitelist
+ RequestSendFailed: If there were problems connecting to the
+ remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
@@ -586,17 +607,18 @@ class MatrixFederationHttpClient(object):
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns:
- Deferred: Succeeds when we get a 2xx HTTP response. The result
- will be the decoded JSON body.
-
- Fails with ``HttpResponseException`` if we get an HTTP response
- code >= 300.
-
- Fails with ``NotRetryingDestination`` if we are not yet ready
- to retry this server.
-
- Fails with ``FederationDeniedError`` if this destination
- is not on our federation whitelist
+ Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
+ result will be the decoded JSON body.
+
+ Raises:
+ HttpResponseException: If we get an HTTP response code >= 300
+ (except 429).
+ NotRetryingDestination: If we are not yet ready to retry this
+ server.
+ FederationDeniedError: If this destination is not on our
+ federation whitelist
+ RequestSendFailed: If there were problems connecting to the
+ remote, due to e.g. DNS failures, connection timeouts etc.
"""
logger.debug("get_json args: %s", args)
@@ -637,17 +659,18 @@ class MatrixFederationHttpClient(object):
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
Returns:
- Deferred: Succeeds when we get a 2xx HTTP response. The result
- will be the decoded JSON body.
-
- Fails with ``HttpResponseException`` if we get an HTTP response
- code >= 300.
-
- Fails with ``NotRetryingDestination`` if we are not yet ready
- to retry this server.
-
- Fails with ``FederationDeniedError`` if this destination
- is not on our federation whitelist
+ Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
+ result will be the decoded JSON body.
+
+ Raises:
+ HttpResponseException: If we get an HTTP response code >= 300
+ (except 429).
+ NotRetryingDestination: If we are not yet ready to retry this
+ server.
+ FederationDeniedError: If this destination is not on our
+ federation whitelist
+ RequestSendFailed: If there were problems connecting to the
+ remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
method="DELETE",
@@ -680,18 +703,20 @@ class MatrixFederationHttpClient(object):
args (dict): Optional dictionary used to create the query string.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
- Returns:
- Deferred: resolves with an (int,dict) tuple of the file length and
- a dict of the response headers.
-
- Fails with ``HttpResponseException`` if we get an HTTP response code
- >= 300
-
- Fails with ``NotRetryingDestination`` if we are not yet ready
- to retry this server.
- Fails with ``FederationDeniedError`` if this destination
- is not on our federation whitelist
+ Returns:
+ Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of
+ the file length and a dict of the response headers.
+
+ Raises:
+ HttpResponseException: If we get an HTTP response code >= 300
+ (except 429).
+ NotRetryingDestination: If we are not yet ready to retry this
+ server.
+ FederationDeniedError: If this destination is not on our
+ federation whitelist
+ RequestSendFailed: If there were problems connecting to the
+ remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
method="GET",
@@ -784,21 +809,21 @@ def check_content_type_is_json(headers):
headers (twisted.web.http_headers.Headers): headers to check
Raises:
- RuntimeError if the
+ RequestSendFailed: if the Content-Type header is missing or isn't JSON
"""
c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
- raise RuntimeError(
+ raise RequestSendFailed(RuntimeError(
"No Content-Type header"
- )
+ ), can_retry=False)
c_type = c_type[0].decode('ascii') # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
- raise RuntimeError(
+ raise RequestSendFailed(RuntimeError(
"Content-Type not application/json: was '%s'" % c_type
- )
+ ), can_retry=False)
def encode_query_args(args):
diff --git a/synapse/http/server.py b/synapse/http/server.py
index b4b25cab19..16fb7935da 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -106,10 +106,10 @@ def wrap_json_request_handler(h):
# trace.
f = failure.Failure()
logger.error(
- "Failed handle request via %r: %r: %s",
- h,
+ "Failed handle request via %r: %r",
+ request.request_metrics.name,
request,
- f.getTraceback().rstrip(),
+ exc_info=(f.type, f.value, f.getTracebackObject()),
)
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
@@ -169,18 +169,18 @@ def _return_html_error(f, request):
)
else:
logger.error(
- "Failed handle request %r: %s",
+ "Failed handle request %r",
request,
- f.getTraceback().rstrip(),
+ exc_info=(f.type, f.value, f.getTracebackObject()),
)
else:
code = http_client.INTERNAL_SERVER_ERROR
msg = "Internal server error"
logger.error(
- "Failed handle request %r: %s",
+ "Failed handle request %r",
request,
- f.getTraceback().rstrip(),
+ exc_info=(f.type, f.value, f.getTracebackObject()),
)
body = HTML_ERROR_TEMPLATE.format(
@@ -468,13 +468,13 @@ def set_cors_headers(request):
Args:
request (twisted.web.http.Request): The http request to add CORs to.
"""
- request.setHeader("Access-Control-Allow-Origin", "*")
+ request.setHeader(b"Access-Control-Allow-Origin", b"*")
request.setHeader(
- "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
+ b"Access-Control-Allow-Methods", b"GET, POST, PUT, DELETE, OPTIONS"
)
request.setHeader(
- "Access-Control-Allow-Headers",
- "Origin, X-Requested-With, Content-Type, Accept, Authorization"
+ b"Access-Control-Allow-Headers",
+ b"Origin, X-Requested-With, Content-Type, Accept, Authorization"
)
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index a1e4b88e6d..528125e737 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -121,16 +121,15 @@ def parse_string(request, name, default=None, required=False,
Args:
request: the twisted HTTP request.
- name (bytes/unicode): the name of the query parameter.
- default (bytes/unicode|None): value to use if the parameter is absent,
+ name (bytes|unicode): the name of the query parameter.
+ default (bytes|unicode|None): value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
- allowed_values (list[bytes/unicode]): List of allowed values for the
+ allowed_values (list[bytes|unicode]): List of allowed values for the
string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
- encoding: The encoding to decode the name to, and decode the string
- content with.
+ encoding (str|None): The encoding to decode the string content with.
Returns:
bytes/unicode|None: A string value or the default. Unicode if encoding
|