diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 4775f6707d..bc28a2959a 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
import collections
import logging
import random
+import time
logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@ SERVER_CACHE = {}
_Server = collections.namedtuple(
- "_Server", "priority weight host port"
+ "_Server", "priority weight host port expires"
)
@@ -92,7 +93,8 @@ class SRVClientEndpoint(object):
host=domain,
port=default_port,
priority=0,
- weight=0
+ weight=0,
+ expires=0,
)
else:
self.default_server = None
@@ -153,7 +155,13 @@ class SRVClientEndpoint(object):
@defer.inlineCallbacks
-def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
+def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
+ cache_entry = cache.get(service_name, None)
+ if cache_entry:
+ if all(s.expires > int(clock.time()) for s in cache_entry):
+ servers = list(cache_entry)
+ defer.returnValue(servers)
+
servers = []
try:
@@ -173,27 +181,26 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
continue
payload = answer.payload
-
host = str(payload.target)
+ srv_ttl = answer.ttl
try:
answers, _, _ = yield dns_client.lookupAddress(host)
except DNSNameError:
continue
- ips = [
- answer.payload.dottedQuad()
- for answer in answers
- if answer.type == dns.A and answer.payload
- ]
-
- for ip in ips:
- servers.append(_Server(
- host=ip,
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight)
- ))
+ for answer in answers:
+ if answer.type == dns.A and answer.payload:
+ ip = answer.payload.dottedQuad()
+ host_ttl = min(srv_ttl, answer.ttl)
+
+ servers.append(_Server(
+ host=ip,
+ port=int(payload.port),
+ priority=int(payload.priority),
+ weight=int(payload.weight),
+ expires=int(clock.time()) + host_ttl,
+ ))
servers.sort()
cache[service_name] = list(servers)
diff --git a/tests/test_dns.py b/tests/test_dns.py
index 637b1606f8..c394c57ee7 100644
--- a/tests/test_dns.py
+++ b/tests/test_dns.py
@@ -21,6 +21,8 @@ from mock import Mock
from synapse.http.endpoint import resolve_service
+from tests.utils import MockClock
+
class DnsTestCase(unittest.TestCase):
@@ -63,14 +65,17 @@ class DnsTestCase(unittest.TestCase):
self.assertEquals(servers[0].host, ip_address)
@defer.inlineCallbacks
- def test_from_cache(self):
+ def test_from_cache_expired_and_dns_fail(self):
dns_client_mock = Mock()
dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
service_name = "test_service.examle.com"
+ entry = Mock(spec_set=["expires"])
+ entry.expires = 0
+
cache = {
- service_name: [object()]
+ service_name: [entry]
}
servers = yield resolve_service(
@@ -83,6 +88,31 @@ class DnsTestCase(unittest.TestCase):
self.assertEquals(servers, cache[service_name])
@defer.inlineCallbacks
+ def test_from_cache(self):
+ clock = MockClock()
+
+ dns_client_mock = Mock(spec_set=['lookupService'])
+ dns_client_mock.lookupService = Mock(spec_set=[])
+
+ service_name = "test_service.examle.com"
+
+ entry = Mock(spec_set=["expires"])
+ entry.expires = 999999999
+
+ cache = {
+ service_name: [entry]
+ }
+
+ servers = yield resolve_service(
+ service_name, dns_client=dns_client_mock, cache=cache, clock=clock,
+ )
+
+ self.assertFalse(dns_client_mock.lookupService.called)
+
+ self.assertEquals(len(servers), 1)
+ self.assertEquals(servers, cache[service_name])
+
+ @defer.inlineCallbacks
def test_empty_cache(self):
dns_client_mock = Mock()
|