diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 4775f6707d..442696d393 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
import collections
import logging
import random
+import time
logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@ SERVER_CACHE = {}
_Server = collections.namedtuple(
- "_Server", "priority weight host port"
+ "_Server", "priority weight host port expires"
)
@@ -74,6 +75,41 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
+class SpiderEndpoint(object):
+ """An endpoint which refuses to connect to blacklisted IP addresses
+ Implements twisted.internet.interfaces.IStreamClientEndpoint.
+ """
+ def __init__(self, reactor, host, port, blacklist, whitelist,
+ endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
+ self.reactor = reactor
+ self.host = host
+ self.port = port
+ self.blacklist = blacklist
+ self.whitelist = whitelist
+ self.endpoint = endpoint
+ self.endpoint_kw_args = endpoint_kw_args
+
+ @defer.inlineCallbacks
+ def connect(self, protocolFactory):
+ address = yield self.reactor.resolve(self.host)
+
+ from netaddr import IPAddress
+ ip_address = IPAddress(address)
+
+ if ip_address in self.blacklist:
+ if self.whitelist is None or ip_address not in self.whitelist:
+ raise ConnectError(
+ "Refusing to spider blacklisted IP address %s" % address
+ )
+
+ logger.info("Connecting to %s:%s", address, self.port)
+ endpoint = self.endpoint(
+ self.reactor, address, self.port, **self.endpoint_kw_args
+ )
+ connection = yield endpoint.connect(protocolFactory)
+ defer.returnValue(connection)
+
+
class SRVClientEndpoint(object):
"""An endpoint which looks up SRV records for a service.
Cycles through the list of servers starting with each call to connect
@@ -92,7 +128,8 @@ class SRVClientEndpoint(object):
host=domain,
port=default_port,
priority=0,
- weight=0
+ weight=0,
+ expires=0,
)
else:
self.default_server = None
@@ -118,7 +155,7 @@ class SRVClientEndpoint(object):
return self.default_server
else:
raise ConnectError(
- "Not server available for %s", self.service_name
+ "Not server available for %s" % self.service_name
)
min_priority = self.servers[0].priority
@@ -153,7 +190,13 @@ class SRVClientEndpoint(object):
@defer.inlineCallbacks
-def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
+def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
+ cache_entry = cache.get(service_name, None)
+ if cache_entry:
+ if all(s.expires > int(clock.time()) for s in cache_entry):
+ servers = list(cache_entry)
+ defer.returnValue(servers)
+
servers = []
try:
@@ -166,34 +209,33 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name('.')):
- raise ConnectError("Service %s unavailable", service_name)
+ raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
-
host = str(payload.target)
+ srv_ttl = answer.ttl
try:
answers, _, _ = yield dns_client.lookupAddress(host)
except DNSNameError:
continue
- ips = [
- answer.payload.dottedQuad()
- for answer in answers
- if answer.type == dns.A and answer.payload
- ]
-
- for ip in ips:
- servers.append(_Server(
- host=ip,
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight)
- ))
+ for answer in answers:
+ if answer.type == dns.A and answer.payload:
+ ip = answer.payload.dottedQuad()
+ host_ttl = min(srv_ttl, answer.ttl)
+
+ servers.append(_Server(
+ host=ip,
+ port=int(payload.port),
+ priority=int(payload.priority),
+ weight=int(payload.weight),
+ expires=int(clock.time()) + host_ttl,
+ ))
servers.sort()
cache[service_name] = list(servers)
|