summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/http/endpoint.py116
-rw-r--r--tests/test_dns.py26
2 files changed, 118 insertions, 24 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index d8923c9abb..241b17f2cb 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import socket
 
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet import defer, reactor
@@ -30,7 +31,10 @@ logger = logging.getLogger(__name__)
 
 SERVER_CACHE = {}
 
-
+# our record of an individual server which can be tried to reach a destination.
+#
+# "host" is actually a dotted-quad or ipv6 address string. Except when there's
+# no SRV record, in which case it is the original hostname.
 _Server = collections.namedtuple(
     "_Server", "priority weight host port expires"
 )
@@ -219,9 +223,10 @@ class SRVClientEndpoint(object):
                 return self.default_server
             else:
                 raise ConnectError(
-                    "Not server available for %s" % self.service_name
+                    "No server available for %s" % self.service_name
                 )
 
+        # look for all servers with the same priority
         min_priority = self.servers[0].priority
         weight_indexes = list(
             (index, server.weight + 1)
@@ -231,11 +236,22 @@ class SRVClientEndpoint(object):
 
         total_weight = sum(weight for index, weight in weight_indexes)
         target_weight = random.randint(0, total_weight)
-
         for index, weight in weight_indexes:
             target_weight -= weight
             if target_weight <= 0:
                 server = self.servers[index]
+                # XXX: this looks totally dubious:
+                #
+                # (a) we never reuse a server until we have been through
+                #     all of the servers at the same priority, so if the
+                #     weights are A: 100, B:1, we always do ABABAB instead of
+                #     AAAA...AAAB (approximately).
+                #
+                # (b) After using all the servers at the lowest priority,
+                #     we move onto the next priority. We should only use the
+                #     second priority if servers at the top priority are
+                #     unreachable.
+                #
                 del self.servers[index]
                 self.used_servers.append(server)
                 return server
@@ -280,26 +296,21 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
                 continue
 
             payload = answer.payload
-            host = str(payload.target)
-            srv_ttl = answer.ttl
 
-            try:
-                answers, _, _ = yield dns_client.lookupAddress(host)
-            except DNSNameError:
-                continue
+            hosts = yield _get_hosts_for_srv_record(
+                dns_client, str(payload.target)
+            )
 
-            for answer in answers:
-                if answer.type == dns.A and answer.payload:
-                    ip = answer.payload.dottedQuad()
-                    host_ttl = min(srv_ttl, answer.ttl)
+            for (ip, ttl) in hosts:
+                host_ttl = min(answer.ttl, ttl)
 
-                    servers.append(_Server(
-                        host=ip,
-                        port=int(payload.port),
-                        priority=int(payload.priority),
-                        weight=int(payload.weight),
-                        expires=int(clock.time()) + host_ttl,
-                    ))
+                servers.append(_Server(
+                    host=ip,
+                    port=int(payload.port),
+                    priority=int(payload.priority),
+                    weight=int(payload.weight),
+                    expires=int(clock.time()) + host_ttl,
+                ))
 
         servers.sort()
         cache[service_name] = list(servers)
@@ -317,3 +328,68 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
             raise e
 
     defer.returnValue(servers)
+
+
+@defer.inlineCallbacks
+def _get_hosts_for_srv_record(dns_client, host):
+    """Look up each of the hosts in a SRV record
+
+    Args:
+        dns_client (twisted.names.dns.IResolver):
+        host (basestring): host to look up
+
+    Returns:
+        Deferred[list[(str, int)]]: a list of (host, ttl) pairs
+
+    """
+    ip4_servers = []
+    ip6_servers = []
+
+    def cb(res):
+        # lookupAddress and lookupIP6Address return a three-tuple
+        # giving the answer, authority, and additional sections of the
+        # response.
+        #
+        # we only care about the answers.
+
+        return res[0]
+
+    def eb(res):
+        res.trap(DNSNameError)
+        return []
+
+    # no logcontexts here, so we can safely fire these off and gatherResults
+    d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
+    d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
+    results = yield defer.gatherResults([d1, d2], consumeErrors=True)
+
+    for result in results:
+        for answer in result:
+            if not answer.payload:
+                continue
+
+            try:
+                if answer.type == dns.A:
+                    ip = answer.payload.dottedQuad()
+                    ip4_servers.append((ip, answer.ttl))
+                elif answer.type == dns.AAAA:
+                    ip = socket.inet_ntop(
+                        socket.AF_INET6, answer.payload.address,
+                    )
+                    ip6_servers.append((ip, answer.ttl))
+                else:
+                    # the most likely candidate here is a CNAME record.
+                    # rfc2782 says srvs may not point to aliases.
+                    logger.warn(
+                        "Ignoring unexpected DNS record type %s for %s",
+                        answer.type, host,
+                    )
+                    continue
+            except Exception as e:
+                logger.warn("Ignoring invalid DNS response for %s: %s",
+                            host, e)
+                continue
+
+    # keep the ipv4 results before the ipv6 results, mostly to match historical
+    # behaviour.
+    defer.returnValue(ip4_servers + ip6_servers)
diff --git a/tests/test_dns.py b/tests/test_dns.py
index c394c57ee7..d08b0f4333 100644
--- a/tests/test_dns.py
+++ b/tests/test_dns.py
@@ -24,15 +24,17 @@ from synapse.http.endpoint import resolve_service
 from tests.utils import MockClock
 
 
+@unittest.DEBUG
 class DnsTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_resolve(self):
         dns_client_mock = Mock()
 
-        service_name = "test_service.examle.com"
+        service_name = "test_service.example.com"
         host_name = "example.com"
         ip_address = "127.0.0.1"
+        ip6_address = "::1"
 
         answer_srv = dns.RRHeader(
             type=dns.SRV,
@@ -48,8 +50,22 @@ class DnsTestCase(unittest.TestCase):
             )
         )
 
-        dns_client_mock.lookupService.return_value = ([answer_srv], None, None)
-        dns_client_mock.lookupAddress.return_value = ([answer_a], None, None)
+        answer_aaaa = dns.RRHeader(
+            type=dns.AAAA,
+            payload=dns.Record_AAAA(
+                address=ip6_address,
+            )
+        )
+
+        dns_client_mock.lookupService.return_value = defer.succeed(
+            ([answer_srv], None, None),
+        )
+        dns_client_mock.lookupAddress.return_value = defer.succeed(
+            ([answer_a], None, None),
+        )
+        dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
+            ([answer_aaaa], None, None),
+        )
 
         cache = {}
 
@@ -59,10 +75,12 @@ class DnsTestCase(unittest.TestCase):
 
         dns_client_mock.lookupService.assert_called_once_with(service_name)
         dns_client_mock.lookupAddress.assert_called_once_with(host_name)
+        dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)
 
-        self.assertEquals(len(servers), 1)
+        self.assertEquals(len(servers), 2)
         self.assertEquals(servers, cache[service_name])
         self.assertEquals(servers[0].host, ip_address)
+        self.assertEquals(servers[1].host, ip6_address)
 
     @defer.inlineCallbacks
     def test_from_cache_expired_and_dns_fail(self):