diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index d8923c9abb..241b17f2cb 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import socket
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor
@@ -30,7 +31,10 @@ logger = logging.getLogger(__name__)
SERVER_CACHE = {}
-
+# our record of an individual server which can be tried to reach a destination.
+#
+# "host" is actually a dotted-quad or IPv6 address string, except when there's
+# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
)
@@ -219,9 +223,10 @@ class SRVClientEndpoint(object):
return self.default_server
else:
raise ConnectError(
- "Not server available for %s" % self.service_name
+ "No server available for %s" % self.service_name
)
+ # look for all servers with the same priority
min_priority = self.servers[0].priority
weight_indexes = list(
(index, server.weight + 1)
@@ -231,11 +236,22 @@ class SRVClientEndpoint(object):
total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight)
-
for index, weight in weight_indexes:
target_weight -= weight
if target_weight <= 0:
server = self.servers[index]
+ # XXX: this looks totally dubious:
+ #
+ # (a) we never reuse a server until we have been through
+ # all of the servers at the same priority, so if the
+ # weights are A: 100, B:1, we always do ABABAB instead of
+ # AAAA...AAAB (approximately).
+ #
+ # (b) After using all the servers at the lowest priority,
+ # we move onto the next priority. We should only use the
+ # second priority if servers at the top priority are
+ # unreachable.
+ #
del self.servers[index]
self.used_servers.append(server)
return server
@@ -280,26 +296,21 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
continue
payload = answer.payload
- host = str(payload.target)
- srv_ttl = answer.ttl
- try:
- answers, _, _ = yield dns_client.lookupAddress(host)
- except DNSNameError:
- continue
+ hosts = yield _get_hosts_for_srv_record(
+ dns_client, str(payload.target)
+ )
- for answer in answers:
- if answer.type == dns.A and answer.payload:
- ip = answer.payload.dottedQuad()
- host_ttl = min(srv_ttl, answer.ttl)
+ for (ip, ttl) in hosts:
+ host_ttl = min(answer.ttl, ttl)
- servers.append(_Server(
- host=ip,
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight),
- expires=int(clock.time()) + host_ttl,
- ))
+ servers.append(_Server(
+ host=ip,
+ port=int(payload.port),
+ priority=int(payload.priority),
+ weight=int(payload.weight),
+ expires=int(clock.time()) + host_ttl,
+ ))
servers.sort()
cache[service_name] = list(servers)
@@ -317,3 +328,68 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
raise e
defer.returnValue(servers)
+
+
+@defer.inlineCallbacks
+def _get_hosts_for_srv_record(dns_client, host):
+ """Look up the A and AAAA records for the target host of a SRV record
+
+ Args:
+ dns_client (twisted.names.dns.IResolver): resolver used for both lookups
+ host (basestring): hostname to look up (the target of the SRV record)
+
+ Returns:
+ Deferred[list[(str, int)]]: (ip address, ttl) pairs, IPv4 before IPv6
+
+ """
+ ip4_servers = []
+ ip6_servers = []
+
+ def cb(res):
+ # lookupAddress and lookupIPV6Address return a three-tuple
+ # giving the answer, authority, and additional sections of the
+ # response.
+ #
+ # we only care about the answers.
+
+ return res[0]
+
+ def eb(res):
+ res.trap(DNSNameError)
+ return []
+
+ # no logcontexts here, so we can safely fire these off and gatherResults
+ d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
+ d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
+ results = yield defer.gatherResults([d1, d2], consumeErrors=True)
+
+ for result in results:
+ for answer in result:
+ if not answer.payload:
+ continue
+
+ try:
+ if answer.type == dns.A:
+ ip = answer.payload.dottedQuad()
+ ip4_servers.append((ip, answer.ttl))
+ elif answer.type == dns.AAAA:
+ ip = socket.inet_ntop(
+ socket.AF_INET6, answer.payload.address,
+ )
+ ip6_servers.append((ip, answer.ttl))
+ else:
+ # the most likely candidate here is a CNAME record.
+ # RFC 2782 says SRV targets may not point to aliases.
+ logger.warn(
+ "Ignoring unexpected DNS record type %s for %s",
+ answer.type, host,
+ )
+ continue
+ except Exception as e:
+ logger.warn("Ignoring invalid DNS response for %s: %s",
+ host, e)
+ continue
+
+ # keep the ipv4 results before the ipv6 results, mostly to match historical
+ # behaviour.
+ defer.returnValue(ip4_servers + ip6_servers)
diff --git a/tests/test_dns.py b/tests/test_dns.py
index c394c57ee7..d08b0f4333 100644
--- a/tests/test_dns.py
+++ b/tests/test_dns.py
@@ -24,15 +24,17 @@ from synapse.http.endpoint import resolve_service
from tests.utils import MockClock
+@unittest.DEBUG
class DnsTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve(self):
dns_client_mock = Mock()
- service_name = "test_service.examle.com"
+ service_name = "test_service.example.com"
host_name = "example.com"
ip_address = "127.0.0.1"
+ ip6_address = "::1"
answer_srv = dns.RRHeader(
type=dns.SRV,
@@ -48,8 +50,22 @@ class DnsTestCase(unittest.TestCase):
)
)
- dns_client_mock.lookupService.return_value = ([answer_srv], None, None)
- dns_client_mock.lookupAddress.return_value = ([answer_a], None, None)
+ answer_aaaa = dns.RRHeader(
+ type=dns.AAAA,
+ payload=dns.Record_AAAA(
+ address=ip6_address,
+ )
+ )
+
+ dns_client_mock.lookupService.return_value = defer.succeed(
+ ([answer_srv], None, None),
+ )
+ dns_client_mock.lookupAddress.return_value = defer.succeed(
+ ([answer_a], None, None),
+ )
+ dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
+ ([answer_aaaa], None, None),
+ )
cache = {}
@@ -59,10 +75,12 @@ class DnsTestCase(unittest.TestCase):
dns_client_mock.lookupService.assert_called_once_with(service_name)
dns_client_mock.lookupAddress.assert_called_once_with(host_name)
+ dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)
- self.assertEquals(len(servers), 1)
+ self.assertEquals(len(servers), 2)
self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, ip_address)
+ self.assertEquals(servers[1].host, ip6_address)
@defer.inlineCallbacks
def test_from_cache_expired_and_dns_fail(self):
|