summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2018-04-09 23:40:06 +0100
committerGitHub <noreply@github.com>2018-04-09 23:40:06 +0100
commit664adb4236d911d6b1619b8aa8489d99772ea86b (patch)
tree25d55210844c8e7457451737492fcdfd81856697
parentMerge pull request #3069 from krombel/update_prometheus_config (diff)
parentRemove address resolution of hosts in SRV records (diff)
downloadsynapse-664adb4236d911d6b1619b8aa8489d99772ea86b.tar.xz
Merge pull request #3016 from silkeh/improve-service-lookups
Improve handling of SRV records for federation connections
-rw-r--r--synapse/http/endpoint.py103
-rw-r--r--tests/test_dns.py29
2 files changed, 10 insertions, 122 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 87639b9151..00572c2897 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,8 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import socket
-
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet import defer, reactor
 from twisted.internet.error import ConnectError
@@ -33,7 +31,7 @@ SERVER_CACHE = {}
 
 # our record of an individual server which can be tried to reach a destination.
 #
-# "host" is actually a dotted-quad or ipv6 address string. Except when there's
+# "host" is the hostname acquired from the SRV record. Except when there's
 # no SRV record, in which case it is the original hostname.
 _Server = collections.namedtuple(
     "_Server", "priority weight host port expires"
@@ -297,20 +295,13 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
 
             payload = answer.payload
 
-            hosts = yield _get_hosts_for_srv_record(
-                dns_client, str(payload.target)
-            )
-
-            for (ip, ttl) in hosts:
-                host_ttl = min(answer.ttl, ttl)
-
-                servers.append(_Server(
-                    host=ip,
-                    port=int(payload.port),
-                    priority=int(payload.priority),
-                    weight=int(payload.weight),
-                    expires=int(clock.time()) + host_ttl,
-                ))
+            servers.append(_Server(
+                host=str(payload.target),
+                port=int(payload.port),
+                priority=int(payload.priority),
+                weight=int(payload.weight),
+                expires=int(clock.time()) + answer.ttl,
+            ))
 
         servers.sort()
         cache[service_name] = list(servers)
@@ -328,81 +319,3 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
             raise e
 
     defer.returnValue(servers)
-
-
-@defer.inlineCallbacks
-def _get_hosts_for_srv_record(dns_client, host):
-    """Look up each of the hosts in a SRV record
-
-    Args:
-        dns_client (twisted.names.dns.IResolver):
-        host (basestring): host to look up
-
-    Returns:
-        Deferred[list[(str, int)]]: a list of (host, ttl) pairs
-
-    """
-    ip4_servers = []
-    ip6_servers = []
-
-    def cb(res):
-        # lookupAddress and lookupIP6Address return a three-tuple
-        # giving the answer, authority, and additional sections of the
-        # response.
-        #
-        # we only care about the answers.
-
-        return res[0]
-
-    def eb(res, record_type):
-        if res.check(DNSNameError):
-            return []
-        logger.warn("Error looking up %s for %s: %s", record_type, host, res)
-        return res
-
-    # no logcontexts here, so we can safely fire these off and gatherResults
-    d1 = dns_client.lookupAddress(host).addCallbacks(
-        cb, eb, errbackArgs=("A", ))
-    d2 = dns_client.lookupIPV6Address(host).addCallbacks(
-        cb, eb, errbackArgs=("AAAA", ))
-    results = yield defer.DeferredList(
-        [d1, d2], consumeErrors=True)
-
-    # if all of the lookups failed, raise an exception rather than blowing out
-    # the cache with an empty result.
-    if results and all(s == defer.FAILURE for (s, _) in results):
-        defer.returnValue(results[0][1])
-
-    for (success, result) in results:
-        if success == defer.FAILURE:
-            continue
-
-        for answer in result:
-            if not answer.payload:
-                continue
-
-            try:
-                if answer.type == dns.A:
-                    ip = answer.payload.dottedQuad()
-                    ip4_servers.append((ip, answer.ttl))
-                elif answer.type == dns.AAAA:
-                    ip = socket.inet_ntop(
-                        socket.AF_INET6, answer.payload.address,
-                    )
-                    ip6_servers.append((ip, answer.ttl))
-                else:
-                    # the most likely candidate here is a CNAME record.
-                    # rfc2782 says srvs may not point to aliases.
-                    logger.warn(
-                        "Ignoring unexpected DNS record type %s for %s",
-                        answer.type, host,
-                    )
-                    continue
-            except Exception as e:
-                logger.warn("Ignoring invalid DNS response for %s: %s",
-                            host, e)
-                continue
-
-    # keep the ipv4 results before the ipv6 results, mostly to match historical
-    # behaviour.
-    defer.returnValue(ip4_servers + ip6_servers)
diff --git a/tests/test_dns.py b/tests/test_dns.py
index d08b0f4333..af607d626f 100644
--- a/tests/test_dns.py
+++ b/tests/test_dns.py
@@ -33,8 +33,6 @@ class DnsTestCase(unittest.TestCase):
 
         service_name = "test_service.example.com"
         host_name = "example.com"
-        ip_address = "127.0.0.1"
-        ip6_address = "::1"
 
         answer_srv = dns.RRHeader(
             type=dns.SRV,
@@ -43,29 +41,9 @@ class DnsTestCase(unittest.TestCase):
             )
         )
 
-        answer_a = dns.RRHeader(
-            type=dns.A,
-            payload=dns.Record_A(
-                address=ip_address,
-            )
-        )
-
-        answer_aaaa = dns.RRHeader(
-            type=dns.AAAA,
-            payload=dns.Record_AAAA(
-                address=ip6_address,
-            )
-        )
-
         dns_client_mock.lookupService.return_value = defer.succeed(
             ([answer_srv], None, None),
         )
-        dns_client_mock.lookupAddress.return_value = defer.succeed(
-            ([answer_a], None, None),
-        )
-        dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
-            ([answer_aaaa], None, None),
-        )
 
         cache = {}
 
@@ -74,13 +52,10 @@ class DnsTestCase(unittest.TestCase):
         )
 
         dns_client_mock.lookupService.assert_called_once_with(service_name)
-        dns_client_mock.lookupAddress.assert_called_once_with(host_name)
-        dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)
 
-        self.assertEquals(len(servers), 2)
+        self.assertEquals(len(servers), 1)
         self.assertEquals(servers, cache[service_name])
-        self.assertEquals(servers[0].host, ip_address)
-        self.assertEquals(servers[1].host, ip6_address)
+        self.assertEquals(servers[0].host, host_name)
 
     @defer.inlineCallbacks
     def test_from_cache_expired_and_dns_fail(self):