diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/http/federation/__init__.py | 14 | ||||
-rw-r--r-- | tests/http/federation/test_srv_resolver.py (renamed from tests/test_dns.py) | 108 |
2 files changed, 108 insertions, 14 deletions
diff --git a/tests/http/federation/__init__.py b/tests/http/federation/__init__.py new file mode 100644 index 0000000000..1453d04571 --- /dev/null +++ b/tests/http/federation/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/test_dns.py b/tests/http/federation/test_srv_resolver.py index 90bd34be34..1271a495e1 100644 --- a/tests/test_dns.py +++ b/tests/http/federation/test_srv_resolver.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,40 +17,63 @@ from mock import Mock from twisted.internet import defer +from twisted.internet.defer import Deferred +from twisted.internet.error import ConnectError from twisted.names import dns, error -from synapse.http.endpoint import resolve_service +from synapse.http.federation.srv_resolver import resolve_service +from synapse.util.logcontext import LoggingContext +from tests import unittest from tests.utils import MockClock -from . import unittest - -@unittest.DEBUG -class DnsTestCase(unittest.TestCase): - @defer.inlineCallbacks +class SrvResolverTestCase(unittest.TestCase): def test_resolve(self): dns_client_mock = Mock() - service_name = "test_service.example.com" - host_name = "example.com" + service_name = b"test_service.example.com" + host_name = b"example.com" answer_srv = dns.RRHeader( type=dns.SRV, payload=dns.Record_SRV(target=host_name) ) - dns_client_mock.lookupService.return_value = defer.succeed( - ([answer_srv], None, None) - ) + result_deferred = Deferred() + dns_client_mock.lookupService.return_value = result_deferred cache = {} - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) + @defer.inlineCallbacks + def do_lookup(): + with LoggingContext("one") as ctx: + resolve_d = resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + + self.assertNoResult(resolve_d) + + # should have reset to the sentinel context + self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + + result = yield resolve_d + + # should have restored our context + self.assertIs(LoggingContext.current_context(), ctx) + + defer.returnValue(result) + + test_d = do_lookup() + self.assertNoResult(test_d) dns_client_mock.lookupService.assert_called_once_with(service_name) + result_deferred.callback( + ([answer_srv], None, None) + ) + + servers = self.successResultOf(test_d) + self.assertEquals(len(servers), 1) self.assertEquals(servers, cache[service_name]) self.assertEquals(servers[0].host, host_name) @@ -127,3 +151,59 @@ class DnsTestCase(unittest.TestCase): self.assertEquals(len(servers), 0) self.assertEquals(len(cache), 0) + + def test_disabled_service(self): + """ + test the behaviour when there is a single record which is ".". + """ + service_name = b"test_service.example.com" + + lookup_deferred = Deferred() + dns_client_mock = Mock() + dns_client_mock.lookupService.return_value = lookup_deferred + cache = {} + + resolve_d = resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + self.assertNoResult(resolve_d) + + # returning a single "." should make the lookup fail with a ConenctError + lookup_deferred.callback(( + [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))], + None, + None, + )) + + self.failureResultOf(resolve_d, ConnectError) + + def test_non_srv_answer(self): + """ + test the behaviour when the dns server gives us a spurious non-SRV response + """ + service_name = b"test_service.example.com" + + lookup_deferred = Deferred() + dns_client_mock = Mock() + dns_client_mock.lookupService.return_value = lookup_deferred + cache = {} + + resolve_d = resolve_service( + service_name, dns_client=dns_client_mock, cache=cache + ) + self.assertNoResult(resolve_d) + + lookup_deferred.callback(( + [ + dns.RRHeader(type=dns.A, payload=dns.Record_A()), + dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")), + ], + None, + None, + )) + + servers = self.successResultOf(resolve_d) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + self.assertEquals(servers[0].host, b"host") |