diff options
Diffstat (limited to 'synapse/http/endpoint.py')
-rw-r--r-- | synapse/http/endpoint.py | 41 |
1 files changed, 24 insertions, 17 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index de5c762f50..a456dc19da 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError import collections import logging import random +import time logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ SERVER_CACHE = {} _Server = collections.namedtuple( - "_Server", "priority weight host port" + "_Server", "priority weight host port expires" ) @@ -123,7 +124,8 @@ class SRVClientEndpoint(object): host=domain, port=default_port, priority=0, - weight=0 + weight=0, + expires=0, ) else: self.default_server = None @@ -184,7 +186,13 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks -def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): +def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time): + cache_entry = cache.get(service_name, None) + if cache_entry: + if all(s.expires > int(clock.time()) for s in cache_entry): + servers = list(cache_entry) + defer.returnValue(servers) + servers = [] try: @@ -204,27 +212,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): continue payload = answer.payload - host = str(payload.target) + srv_ttl = answer.ttl try: answers, _, _ = yield dns_client.lookupAddress(host) except DNSNameError: continue - ips = [ - answer.payload.dottedQuad() - for answer in answers - if answer.type == dns.A and answer.payload - ] - - for ip in ips: - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight) - )) + for answer in answers: + if answer.type == dns.A and answer.payload: + ip = answer.payload.dottedQuad() + host_ttl = min(srv_ttl, answer.ttl) + + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + host_ttl, + )) servers.sort() cache[service_name] = list(servers) |