diff options
-rw-r--r-- | synapse/http/endpoint.py | 101 | ||||
-rw-r--r-- | tests/test_dns.py | 115 |
2 files changed, 186 insertions, 30 deletions
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index 4341ded96a..a9e024a415 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -17,7 +17,7 @@ from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint from twisted.internet import defer from twisted.internet.error import ConnectError from twisted.names import client, dns -from twisted.names.error import DNSNameError +from twisted.names.error import DNSNameError, DomainError import collections import logging @@ -27,6 +27,14 @@ import random logger = logging.getLogger(__name__) +SERVER_CACHE = {} + + +_Server = collections.namedtuple( + "_Server", "priority weight host port" +) + + def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, timeout=None): """Construct an endpoint for the given matrix destination. @@ -73,10 +81,6 @@ class SRVClientEndpoint(object): Implements twisted.internet.interfaces.IStreamClientEndpoint. """ - _Server = collections.namedtuple( - "_Server", "priority weight host port" - ) - def __init__(self, reactor, service, domain, protocol="tcp", default_port=None, endpoint=TCP4ClientEndpoint, endpoint_kw_args={}): @@ -101,32 +105,8 @@ class SRVClientEndpoint(object): @defer.inlineCallbacks def fetch_servers(self): - try: - answers, auth, add = yield client.lookupService(self.service_name) - except DNSNameError: - answers = [] - - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name('.')): - raise ConnectError("Service %s unavailable", self.service_name) - - self.servers = [] self.used_servers = [] - - for answer in answers: - if answer.type != dns.SRV or not answer.payload: - continue - payload = answer.payload - self.servers.append(self._Server( - host=str(payload.target), - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight) - )) - - self.servers.sort() + self.servers = yield resolve_service(self.service_name) def pick_server(self): if not self.servers: @@ -170,3 +150,64 @@ class SRVClientEndpoint(object): ) connection = yield endpoint.connect(protocolFactory) defer.returnValue(connection) + + +@defer.inlineCallbacks +def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE): + servers = [] + + try: + try: + answers, _, _ = yield dns_client.lookupService(service_name) + except DNSNameError: + defer.returnValue([]) + + if (len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name('.')): + raise ConnectError("Service %s unavailable", service_name) + + for answer in answers: + if answer.type != dns.SRV or not answer.payload: + continue + + payload = answer.payload + + host = str(payload.target) + + try: + answers, _, _ = yield dns_client.lookupAddress(host) + except DNSNameError: + continue + + ips = [ + answer.payload.dottedQuad() + for answer in answers + if answer.type == dns.A and answer.payload + ] + + for ip in ips: + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight) + )) + + servers.sort() + cache[service_name] = list(servers) + except DomainError as e: + # We failed to resolve the name (other than a NameError) + # Try something in the cache, else rereaise + cache_entry = cache.get(service_name, None) + if cache_entry: + logger.warn( + "Failed to resolve %r, falling back to cache. %r", + service_name, e + ) + servers = list(cache_entry) + else: + raise e + + defer.returnValue(servers) diff --git a/tests/test_dns.py b/tests/test_dns.py new file mode 100644 index 0000000000..637b1606f8 --- /dev/null +++ b/tests/test_dns.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import unittest +from twisted.internet import defer +from twisted.names import dns, error + +from mock import Mock + +from synapse.http.endpoint import resolve_service + + +class DnsTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def test_resolve(self): + dns_client_mock = Mock() + + service_name = "test_service.examle.com" + host_name = "example.com" + ip_address = "127.0.0.1" + + answer_srv = dns.RRHeader( + type=dns.SRV, + payload=dns.Record_SRV( + target=host_name, + ) + ) + + answer_a = dns.RRHeader( + type=dns.A, + payload=dns.Record_A( + address=ip_address, + ) + ) + + dns_client_mock.lookupService.return_value = ([answer_srv], None, None) + dns_client_mock.lookupAddress.return_value = ([answer_a], None, None) + + cache = {} + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + dns_client_mock.lookupService.assert_called_once_with(service_name) + dns_client_mock.lookupAddress.assert_called_once_with(host_name) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + self.assertEquals(servers[0].host, ip_address) + + @defer.inlineCallbacks + def test_from_cache(self): + dns_client_mock = Mock() + dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) + + service_name = "test_service.examle.com" + + cache = { + service_name: [object()] + } + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + dns_client_mock.lookupService.assert_called_once_with(service_name) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + + @defer.inlineCallbacks + def test_empty_cache(self): + dns_client_mock = Mock() + + dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) + + service_name = "test_service.examle.com" + + cache = {} + + with self.assertRaises(error.DNSServerError): + yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + @defer.inlineCallbacks + def test_name_error(self): + dns_client_mock = Mock() + + dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) + + service_name = "test_service.examle.com" + + cache = {} + + servers = yield resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + self.assertEquals(len(servers), 0) + self.assertEquals(len(cache), 0) |