diff --git a/changelog.d/4427.misc b/changelog.d/4427.misc
new file mode 100644
index 0000000000..75500bdbc2
--- /dev/null
+++ b/changelog.d/4427.misc
@@ -0,0 +1 @@
+Refactor and cleanup for SRV record lookup
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 1c3b7ea28a..815f8ff2f7 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,29 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
import logging
import random
import re
-import time
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.error import ConnectError
-from twisted.names import client, dns
-from twisted.names.error import DNSNameError, DomainError
-logger = logging.getLogger(__name__)
-
-SERVER_CACHE = {}
+from synapse.http.federation.srv_resolver import Server, resolve_service
-# our record of an individual server which can be tried to reach a destination.
-#
-# "host" is the hostname acquired from the SRV record. Except when there's
-# no SRV record, in which case it is the original hostname.
-_Server = collections.namedtuple(
- "_Server", "priority weight host port expires"
-)
+logger = logging.getLogger(__name__)
def parse_server_name(server_name):
@@ -165,12 +153,9 @@ class SRVClientEndpoint(object):
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
if default_port is not None:
- self.default_server = _Server(
+ self.default_server = Server(
host=domain,
port=default_port,
- priority=0,
- weight=0,
- expires=0,
)
else:
self.default_server = None
@@ -240,57 +225,3 @@ class SRVClientEndpoint(object):
)
connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection)
-
-
-@defer.inlineCallbacks
-def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
- cache_entry = cache.get(service_name, None)
- if cache_entry:
- if all(s.expires > int(clock.time()) for s in cache_entry):
- servers = list(cache_entry)
- defer.returnValue(servers)
-
- servers = []
-
- try:
- try:
- answers, _, _ = yield dns_client.lookupService(service_name)
- except DNSNameError:
- defer.returnValue([])
-
- if (len(answers) == 1
- and answers[0].type == dns.SRV
- and answers[0].payload
- and answers[0].payload.target == dns.Name(b'.')):
- raise ConnectError("Service %s unavailable" % service_name)
-
- for answer in answers:
- if answer.type != dns.SRV or not answer.payload:
- continue
-
- payload = answer.payload
-
- servers.append(_Server(
- host=str(payload.target),
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight),
- expires=int(clock.time()) + answer.ttl,
- ))
-
- servers.sort()
- cache[service_name] = list(servers)
- except DomainError as e:
- # We failed to resolve the name (other than a NameError)
- # Try something in the cache, else rereaise
- cache_entry = cache.get(service_name, None)
- if cache_entry:
- logger.warn(
- "Failed to resolve %r, falling back to cache. %r",
- service_name, e
- )
- servers = list(cache_entry)
- else:
- raise e
-
- defer.returnValue(servers)
diff --git a/synapse/http/federation/__init__.py b/synapse/http/federation/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/synapse/http/federation/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
new file mode 100644
index 0000000000..c49b82c394
--- /dev/null
+++ b/synapse/http/federation/srv_resolver.py
@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+
+import attr
+
+from twisted.internet import defer
+from twisted.internet.error import ConnectError
+from twisted.names import client, dns
+from twisted.names.error import DNSNameError, DomainError
+
+from synapse.util.logcontext import make_deferred_yieldable
+
+logger = logging.getLogger(__name__)
+
+SERVER_CACHE = {}
+
+
+@attr.s
+class Server(object):
+ """
+ Our record of an individual server which can be tried to reach a destination.
+
+ Attributes:
+ host (bytes): target hostname
+ port (int):
+ priority (int):
+ weight (int):
+ expires (int): when the cache should expire this record - in *seconds* since
+ the epoch
+ """
+ host = attr.ib()
+ port = attr.ib()
+ priority = attr.ib(default=0)
+ weight = attr.ib(default=0)
+ expires = attr.ib(default=0)
+
+
+@defer.inlineCallbacks
+def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
+ """Look up a SRV record, with caching
+
+ The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
+ but the cache never gets populated), so we add our own caching layer here.
+
+ Args:
+ service_name (unicode|bytes): record to look up
+ dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
+ cache (dict): cache object
+ clock (object): clock implementation. must provide a time() method.
+
+ Returns:
+ Deferred[list[Server]]: a list of the SRV records, or an empty list if none found
+ """
+ # TODO: the dns client handles both unicode names (encoding via idna) and pre-encoded
+ # byteses; however they will obviously end up as separate entries in the cache. We
+ # should pick one form and stick with it.
+ cache_entry = cache.get(service_name, None)
+ if cache_entry:
+ if all(s.expires > int(clock.time()) for s in cache_entry):
+ servers = list(cache_entry)
+ defer.returnValue(servers)
+
+ try:
+ answers, _, _ = yield make_deferred_yieldable(
+ dns_client.lookupService(service_name),
+ )
+ except DNSNameError:
+ # TODO: cache this. We can get the SOA out of the exception, and use
+ # the negative-TTL value.
+ defer.returnValue([])
+ except DomainError as e:
+ # We failed to resolve the name (other than a NameError)
+ # Try something in the cache, else rereaise
+ cache_entry = cache.get(service_name, None)
+ if cache_entry:
+ logger.warn(
+ "Failed to resolve %r, falling back to cache. %r",
+ service_name, e
+ )
+ defer.returnValue(list(cache_entry))
+ else:
+ raise e
+
+ if (len(answers) == 1
+ and answers[0].type == dns.SRV
+ and answers[0].payload
+ and answers[0].payload.target == dns.Name(b'.')):
+ raise ConnectError("Service %s unavailable" % service_name)
+
+ servers = []
+
+ for answer in answers:
+ if answer.type != dns.SRV or not answer.payload:
+ continue
+
+ payload = answer.payload
+
+ servers.append(Server(
+ host=payload.target.name,
+ port=payload.port,
+ priority=payload.priority,
+ weight=payload.weight,
+ expires=int(clock.time()) + answer.ttl,
+ ))
+
+ servers.sort() # FIXME: get rid of this (it's broken by the attrs change)
+ cache[service_name] = list(servers)
+ defer.returnValue(servers)
diff --git a/tests/http/federation/__init__.py b/tests/http/federation/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/tests/http/federation/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/test_dns.py b/tests/http/federation/test_srv_resolver.py
index 90bd34be34..1271a495e1 100644
--- a/tests/test_dns.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,40 +17,63 @@
from mock import Mock
from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.internet.error import ConnectError
from twisted.names import dns, error
-from synapse.http.endpoint import resolve_service
+from synapse.http.federation.srv_resolver import resolve_service
+from synapse.util.logcontext import LoggingContext
+from tests import unittest
from tests.utils import MockClock
-from . import unittest
-
-@unittest.DEBUG
-class DnsTestCase(unittest.TestCase):
- @defer.inlineCallbacks
+class SrvResolverTestCase(unittest.TestCase):
def test_resolve(self):
dns_client_mock = Mock()
- service_name = "test_service.example.com"
- host_name = "example.com"
+ service_name = b"test_service.example.com"
+ host_name = b"example.com"
answer_srv = dns.RRHeader(
type=dns.SRV, payload=dns.Record_SRV(target=host_name)
)
- dns_client_mock.lookupService.return_value = defer.succeed(
- ([answer_srv], None, None)
- )
+ result_deferred = Deferred()
+ dns_client_mock.lookupService.return_value = result_deferred
cache = {}
- servers = yield resolve_service(
- service_name, dns_client=dns_client_mock, cache=cache
- )
+ @defer.inlineCallbacks
+ def do_lookup():
+ with LoggingContext("one") as ctx:
+ resolve_d = resolve_service(
+ service_name, dns_client=dns_client_mock, cache=cache
+ )
+
+ self.assertNoResult(resolve_d)
+
+ # should have reset to the sentinel context
+ self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+
+ result = yield resolve_d
+
+ # should have restored our context
+ self.assertIs(LoggingContext.current_context(), ctx)
+
+ defer.returnValue(result)
+
+ test_d = do_lookup()
+ self.assertNoResult(test_d)
dns_client_mock.lookupService.assert_called_once_with(service_name)
+ result_deferred.callback(
+ ([answer_srv], None, None)
+ )
+
+ servers = self.successResultOf(test_d)
+
self.assertEquals(len(servers), 1)
self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, host_name)
@@ -127,3 +151,59 @@ class DnsTestCase(unittest.TestCase):
self.assertEquals(len(servers), 0)
self.assertEquals(len(cache), 0)
+
+ def test_disabled_service(self):
+ """
+ test the behaviour when there is a single record which is ".".
+ """
+ service_name = b"test_service.example.com"
+
+ lookup_deferred = Deferred()
+ dns_client_mock = Mock()
+ dns_client_mock.lookupService.return_value = lookup_deferred
+ cache = {}
+
+ resolve_d = resolve_service(
+ service_name, dns_client=dns_client_mock, cache=cache
+ )
+ self.assertNoResult(resolve_d)
+
+ # returning a single "." should make the lookup fail with a ConenctError
+ lookup_deferred.callback((
+ [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))],
+ None,
+ None,
+ ))
+
+ self.failureResultOf(resolve_d, ConnectError)
+
+ def test_non_srv_answer(self):
+ """
+ test the behaviour when the dns server gives us a spurious non-SRV response
+ """
+ service_name = b"test_service.example.com"
+
+ lookup_deferred = Deferred()
+ dns_client_mock = Mock()
+ dns_client_mock.lookupService.return_value = lookup_deferred
+ cache = {}
+
+ resolve_d = resolve_service(
+ service_name, dns_client=dns_client_mock, cache=cache
+ )
+ self.assertNoResult(resolve_d)
+
+ lookup_deferred.callback((
+ [
+ dns.RRHeader(type=dns.A, payload=dns.Record_A()),
+ dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")),
+ ],
+ None,
+ None,
+ ))
+
+ servers = self.successResultOf(resolve_d)
+
+ self.assertEquals(len(servers), 1)
+ self.assertEquals(servers, cache[service_name])
+ self.assertEquals(servers[0].host, b"host")
|