diff --git a/synapse/http/client.py b/synapse/http/client.py
index cbd45b2bbe..6c89b20984 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -15,17 +15,24 @@
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
-from synapse.api.errors import CodeMessageException
+from synapse.api.errors import (
+ CodeMessageException, SynapseError, Codes,
+)
from synapse.util.logcontext import preserve_context_over_fn
import synapse.metrics
+from synapse.http.endpoint import SpiderEndpoint
from canonicaljson import encode_canonical_json
-from twisted.internet import defer, reactor, ssl
+from twisted.internet import defer, reactor, ssl, protocol
+from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
from twisted.web.client import (
- Agent, readBody, FileBodyProducer, PartialDownloadError,
+ BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
+ readBody, FileBodyProducer, PartialDownloadError,
)
+from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
+from twisted.web._newclient import ResponseDone
from StringIO import StringIO
@@ -238,6 +245,107 @@ class SimpleHttpClient(object):
else:
raise CodeMessageException(response.code, body)
+ # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
+ # The two should be factored out.
+
+ @defer.inlineCallbacks
+ def get_file(self, url, output_stream, max_size=None):
+ """GETs a file from a given URL
+ Args:
+ url (str): The URL to GET
+ output_stream (file): File to write the response body to.
+ Returns:
+ A (int,dict,string,int) tuple of the file length, dict of the response
+ headers, absolute URI of the response and HTTP response code.
+ """
+
+ response = yield self.request(
+ "GET",
+ url.encode("ascii"),
+ headers=Headers({
+ b"User-Agent": [self.user_agent],
+ })
+ )
+
+ headers = dict(response.headers.getAllRawHeaders())
+
+ if 'Content-Length' in headers and headers['Content-Length'] > max_size:
+ logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
+ raise SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (self.max_size,),
+ Codes.TOO_LARGE,
+ )
+
+ if response.code > 299:
+ logger.warn("Got %d when downloading %s" % (response.code, url))
+ raise SynapseError(
+ 502,
+ "Got error %d" % (response.code,),
+ Codes.UNKNOWN,
+ )
+
+ # TODO: if our Content-Type is HTML or something, just read the first
+ # N bytes into RAM rather than saving it all to disk only to read it
+ # straight back in again
+
+ try:
+ length = yield preserve_context_over_fn(
+ _readBodyToFile,
+ response, output_stream, max_size
+ )
+ except Exception as e:
+ logger.exception("Failed to download body")
+ raise SynapseError(
+ 502,
+ ("Failed to download remote body: %s" % e),
+ Codes.UNKNOWN,
+ )
+
+ defer.returnValue((length, headers, response.request.absoluteURI, response.code))
+
+
+# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
+# The two should be factored out.
+
+class _ReadBodyToFileProtocol(protocol.Protocol):
+ def __init__(self, stream, deferred, max_size):
+ self.stream = stream
+ self.deferred = deferred
+ self.length = 0
+ self.max_size = max_size
+
+ def dataReceived(self, data):
+ self.stream.write(data)
+ self.length += len(data)
+ if self.max_size is not None and self.length >= self.max_size:
+ self.deferred.errback(SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (self.max_size,),
+ Codes.TOO_LARGE,
+ ))
+ self.deferred = defer.Deferred()
+ self.transport.loseConnection()
+
+ def connectionLost(self, reason):
+ if reason.check(ResponseDone):
+ self.deferred.callback(self.length)
+ elif reason.check(PotentialDataLoss):
+ # stolen from https://github.com/twisted/treq/pull/49/files
+ # http://twistedmatrix.com/trac/ticket/4840
+ self.deferred.callback(self.length)
+ else:
+ self.deferred.errback(reason)
+
+
+# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
+# The two should be factored out.
+
+def _readBodyToFile(response, stream, max_size):
+ d = defer.Deferred()
+ response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
+ return d
+
class CaptchaServerHttpClient(SimpleHttpClient):
"""
@@ -269,6 +377,59 @@ class CaptchaServerHttpClient(SimpleHttpClient):
defer.returnValue(e.response)
+class SpiderEndpointFactory(object):
+ def __init__(self, hs):
+ self.blacklist = hs.config.url_preview_ip_range_blacklist
+ self.policyForHTTPS = hs.get_http_client_context_factory()
+
+ def endpointForURI(self, uri):
+ logger.info("Getting endpoint for %s", uri.toBytes())
+ if uri.scheme == "http":
+ return SpiderEndpoint(
+ reactor, uri.host, uri.port, self.blacklist,
+ endpoint=TCP4ClientEndpoint,
+ endpoint_kw_args={
+ 'timeout': 15
+ },
+ )
+ elif uri.scheme == "https":
+ tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
+ return SpiderEndpoint(
+ reactor, uri.host, uri.port, self.blacklist,
+ endpoint=SSL4ClientEndpoint,
+ endpoint_kw_args={
+ 'sslContextFactory': tlsPolicy,
+ 'timeout': 15
+ },
+ )
+ else:
+ logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
+
+
+class SpiderHttpClient(SimpleHttpClient):
+ """
+ Separate HTTP client for spidering arbitrary URLs.
+ Special in that it follows retries and has a UA that looks
+ like a browser.
+
+ used by the preview_url endpoint in the content repo.
+ """
+ def __init__(self, hs):
+ SimpleHttpClient.__init__(self, hs)
+ # clobber the base class's agent and UA:
+ self.agent = ContentDecoderAgent(
+ BrowserLikeRedirectAgent(
+ Agent.usingEndpointFactory(
+ reactor,
+ SpiderEndpointFactory(hs)
+ )
+ ), [('gzip', GzipDecoder)]
+ )
+ # We could look like Chrome:
+ # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
+ # Chrome Safari" % hs.version_string)
+
+
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index bc28a2959a..a456dc19da 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -75,6 +75,37 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
+class SpiderEndpoint(object):
+ """An endpoint which refuses to connect to blacklisted IP addresses
+ Implements twisted.internet.interfaces.IStreamClientEndpoint.
+ """
+ def __init__(self, reactor, host, port, blacklist,
+ endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
+ self.reactor = reactor
+ self.host = host
+ self.port = port
+ self.blacklist = blacklist
+ self.endpoint = endpoint
+ self.endpoint_kw_args = endpoint_kw_args
+
+ @defer.inlineCallbacks
+ def connect(self, protocolFactory):
+ address = yield self.reactor.resolve(self.host)
+
+ from netaddr import IPAddress
+ if IPAddress(address) in self.blacklist:
+ raise ConnectError(
+ "Refusing to spider blacklisted IP address %s" % address
+ )
+
+ logger.info("Connecting to %s:%s", address, self.port)
+ endpoint = self.endpoint(
+ self.reactor, address, self.port, **self.endpoint_kw_args
+ )
+ connection = yield endpoint.connect(protocolFactory)
+ defer.returnValue(connection)
+
+
class SRVClientEndpoint(object):
"""An endpoint which looks up SRV records for a service.
Cycles through the list of servers starting with each call to connect
@@ -120,7 +151,7 @@ class SRVClientEndpoint(object):
return self.default_server
else:
raise ConnectError(
- "Not server available for %s", self.service_name
+ "Not server available for %s" % self.service_name
)
min_priority = self.servers[0].priority
@@ -174,7 +205,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name('.')):
- raise ConnectError("Service %s unavailable", service_name)
+ raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
|