diff options
Diffstat (limited to 'synapse/http/endpoint.py')
-rw-r--r-- | synapse/http/endpoint.py | 76 |
1 files changed, 57 insertions, 19 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 4775f6707d..a456dc19da 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError import collections import logging import random +import time logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ SERVER_CACHE = {} _Server = collections.namedtuple( - "_Server", "priority weight host port" + "_Server", "priority weight host port expires" ) @@ -74,6 +75,37 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, return transport_endpoint(reactor, domain, port, **endpoint_kw_args) +class SpiderEndpoint(object): + """An endpoint which refuses to connect to blacklisted IP addresses + Implements twisted.internet.interfaces.IStreamClientEndpoint. + """ + def __init__(self, reactor, host, port, blacklist, + endpoint=TCP4ClientEndpoint, endpoint_kw_args={}): + self.reactor = reactor + self.host = host + self.port = port + self.blacklist = blacklist + self.endpoint = endpoint + self.endpoint_kw_args = endpoint_kw_args + + @defer.inlineCallbacks + def connect(self, protocolFactory): + address = yield self.reactor.resolve(self.host) + + from netaddr import IPAddress + if IPAddress(address) in self.blacklist: + raise ConnectError( + "Refusing to spider blacklisted IP address %s" % address + ) + + logger.info("Connecting to %s:%s", address, self.port) + endpoint = self.endpoint( + self.reactor, address, self.port, **self.endpoint_kw_args + ) + connection = yield endpoint.connect(protocolFactory) + defer.returnValue(connection) + + class SRVClientEndpoint(object): """An endpoint which looks up SRV records for a service. Cycles through the list of servers starting with each call to connect @@ -92,7 +124,8 @@ class SRVClientEndpoint(object): host=domain, port=default_port, priority=0, - weight=0 + weight=0, + expires=0, ) else: self.default_server = None @@ -118,7 +151,7 @@ class SRVClientEndpoint(object): return self.default_server else: raise ConnectError( - "Not server available for %s", self.service_name + "Not server available for %s" % self.service_name ) min_priority = self.servers[0].priority @@ -153,7 +186,13 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks -def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): +def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): + cache_entry = cache.get(service_name, None) + if cache_entry: + if all(s.expires > int(clock.time()) for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + servers = [] try: @@ -166,34 +205,33 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): and answers[0].type == dns.SRV and answers[0].payload and answers[0].payload.target == dns.Name('.')): - raise ConnectError("Service %s unavailable", service_name) + raise ConnectError("Service %s unavailable" % service_name) for answer in answers: if answer.type != dns.SRV or not answer.payload: continue payload = answer.payload - host = str(payload.target) + srv_ttl = answer.ttl try: answers, _, _ = yield dns_client.lookupAddress(host) except DNSNameError: continue - ips = [ - answer.payload.dottedQuad() - for answer in answers - if answer.type == dns.A and answer.payload - ] - - for ip in ips: - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight) - )) + for answer in answers: + if answer.type == dns.A and answer.payload: + ip = answer.payload.dottedQuad() + host_ttl = min(srv_ttl, answer.ttl) + + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + host_ttl, + )) servers.sort() cache[service_name] = list(servers) |