diff options
Diffstat (limited to 'synapse/http/federation/srv_resolver.py')
-rw-r--r-- | synapse/http/federation/srv_resolver.py | 37 |
1 files changed, 20 insertions, 17 deletions
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 71830c549d..1f22f78a75 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -45,6 +45,7 @@ class Server(object): expires (int): when the cache should expire this record - in *seconds* since the epoch """ + host = attr.ib() port = attr.ib() priority = attr.ib(default=0) @@ -79,9 +80,7 @@ def pick_server_from_list(server_list): return s.host, s.port # this should be impossible. - raise RuntimeError( - "pick_server_from_list got to end of eligible server list.", - ) + raise RuntimeError("pick_server_from_list got to end of eligible server list.") class SrvResolver(object): @@ -95,6 +94,7 @@ class SrvResolver(object): cache (dict): cache object get_time (callable): clock implementation. Should return seconds since the epoch """ + def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): self._dns_client = dns_client self._cache = cache @@ -124,7 +124,7 @@ class SrvResolver(object): try: answers, _, _ = yield make_deferred_yieldable( - self._dns_client.lookupService(service_name), + self._dns_client.lookupService(service_name) ) except DNSNameError: # TODO: cache this. We can get the SOA out of the exception, and use @@ -136,17 +136,18 @@ class SrvResolver(object): cache_entry = self._cache.get(service_name, None) if cache_entry: logger.warn( - "Failed to resolve %r, falling back to cache. %r", - service_name, e + "Failed to resolve %r, falling back to cache. %r", service_name, e ) defer.returnValue(list(cache_entry)) else: raise e - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name(b'.')): + if ( + len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name(b".") + ): raise ConnectError("Service %s unavailable" % service_name) servers = [] @@ -157,13 +158,15 @@ class SrvResolver(object): payload = answer.payload - servers.append(Server( - host=payload.target.name, - port=payload.port, - priority=payload.priority, - weight=payload.weight, - expires=now + answer.ttl, - )) + servers.append( + Server( + host=payload.target.name, + port=payload.port, + priority=payload.priority, + weight=payload.weight, + expires=now + answer.ttl, + ) + ) self._cache[service_name] = list(servers) defer.returnValue(servers) |