From 290777b3d96df17292d40de240f7bd7b162fea4e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 18 Sep 2017 18:31:01 +0100 Subject: Clean up and document handling of logcontexts in Keyring (#2452) I'm still unclear on what the intended behaviour for `verify_json_objects_for_server` is, but at least I now understand the behaviour of most of the things it calls... --- tests/crypto/test_keyring.py | 74 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/crypto/test_keyring.py (limited to 'tests') diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py new file mode 100644 index 0000000000..da2c9e44e7 --- /dev/null +++ b/tests/crypto/test_keyring.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.crypto import keyring +from synapse.util.logcontext import LoggingContext +from tests import utils, unittest +from twisted.internet import defer + + +class KeyringTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield utils.setup_test_homeserver(handlers=None) + + @defer.inlineCallbacks + def test_wait_for_previous_lookups(self): + sentinel_context = LoggingContext.current_context() + + kr = keyring.Keyring(self.hs) + + def check_context(_, expected): + self.assertEquals( + LoggingContext.current_context().test_key, expected + ) + + lookup_1_deferred = defer.Deferred() + lookup_2_deferred = defer.Deferred() + + with LoggingContext("one") as context_one: + context_one.test_key = "one" + + wait_1_deferred = kr.wait_for_previous_lookups( + ["server1"], + {"server1": lookup_1_deferred}, + ) + + # there were no previous lookups, so the deferred should be ready + self.assertTrue(wait_1_deferred.called) + # ... so we should have preserved the LoggingContext. + self.assertIs(LoggingContext.current_context(), context_one) + wait_1_deferred.addBoth(check_context, "one") + + with LoggingContext("two") as context_two: + context_two.test_key = "two" + + # set off another wait. It should block because the first lookup + # hasn't yet completed. + wait_2_deferred = kr.wait_for_previous_lookups( + ["server1"], + {"server1": lookup_2_deferred}, + ) + self.assertFalse(wait_2_deferred.called) + # ... so we should have reset the LoggingContext. + self.assertIs(LoggingContext.current_context(), sentinel_context) + wait_2_deferred.addBoth(check_context, "two") + + # let the first lookup complete (in the sentinel context) + lookup_1_deferred.callback(None) + + # now the second wait should complete and restore our + # loggingcontext. + yield wait_2_deferred -- cgit 1.5.1 From aa620d09a01c226d7a6fbc0d839d8abd347a2b2e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Sep 2017 16:08:14 +0100 Subject: Add a config option to block all room invites (#2457) - allows sysadmins the ability to lock down their servers so that people can't send their users room invites. --- synapse/api/auth.py | 8 ++++++++ synapse/config/server.py | 10 ++++++++++ synapse/handlers/federation.py | 3 +++ synapse/handlers/room_member.py | 22 ++++++++++++++++++++++ tests/utils.py | 1 + 5 files changed, 44 insertions(+) (limited to 'tests') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e3da45b416..72858cca1f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -519,6 +519,14 @@ class Auth(object): ) def is_server_admin(self, user): + """ Check if the given user is a local server admin. + + Args: + user (str): mxid of user to check + + Returns: + bool: True if the user is an admin + """ return self.store.is_server_admin(user) @defer.inlineCallbacks diff --git a/synapse/config/server.py b/synapse/config/server.py index 89d61a0503..c9a1715f1f 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -43,6 +43,12 @@ class ServerConfig(Config): self.filter_timeline_limit = config.get("filter_timeline_limit", -1) + # Whether we should block invites sent to users on this server + # (other than those sent by local server admins) + self.block_non_admin_invites = config.get( + "block_non_admin_invites", False, + ) + if self.public_baseurl is not None: if self.public_baseurl[-1] != '/': self.public_baseurl += '/' @@ -194,6 +200,10 @@ class ServerConfig(Config): # and sync operations. The default value is -1, means no upper limit. # filter_timeline_limit: 5000 + # Whether room invites to users on this server should be blocked + # (except those sent by local server admins). The default is False. + # block_non_admin_invites: True + # List of ports that Synapse should listen on, their purpose and their # configuration. listeners: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2637f41dcd..18f87cad67 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1074,6 +1074,9 @@ class FederationHandler(BaseHandler): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") + if self.hs.config.block_non_admin_invites: + raise SynapseError(403, "This server does not accept room invites") + membership = event.content.get("membership") if event.type != EventTypes.Member or membership != Membership.INVITE: raise SynapseError(400, "The event was not an m.room.member invite event") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b3f979b246..9a498c2d3e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -191,6 +191,8 @@ class RoomMemberHandler(BaseHandler): if action in ["kick", "unban"]: effective_membership_state = "leave" + # if this is a join with a 3pid signature, we may need to turn a 3pid + # invite into a normal invite before we can handle the join. if third_party_signed is not None: replication = self.hs.get_replication_layer() yield replication.exchange_third_party_invite( @@ -208,6 +210,16 @@ class RoomMemberHandler(BaseHandler): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") + if (effective_membership_state == "invite" and + self.hs.config.block_non_admin_invites): + is_requester_admin = yield self.auth.is_server_admin( + requester.user, + ) + if not is_requester_admin: + raise SynapseError( + 403, "Invites have been disabled on this server", + ) + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, @@ -471,6 +483,16 @@ class RoomMemberHandler(BaseHandler): requester, txn_id ): + if self.hs.config.block_non_admin_invites: + is_requester_admin = yield self.auth.is_server_admin( + requester.user, + ) + if not is_requester_admin: + raise SynapseError( + 403, "Invites have been disabled on this server", + Codes.FORBIDDEN, + ) + invitee = yield self._lookup_3pid( id_server, medium, address ) diff --git a/tests/utils.py b/tests/utils.py index 4f7e32b3ab..3c81a3e16d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -56,6 +56,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.worker_replication_url = "" config.worker_app = None config.email_enable_notifs = False + config.block_non_admin_invites = False config.use_frozen_dicts = True config.database_config = {"name": "sqlite3"} -- cgit 1.5.1 From 9864efa5321ad5afa522d9ecb3eb48e1f50fb852 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 19 Sep 2017 23:25:44 +0100 Subject: Fix concurrent server_key requests (#2458) Fix a bug where we could end up firing off multiple requests for server_keys for the same server at the same time. --- synapse/crypto/keyring.py | 4 ++- tests/crypto/test_keyring.py | 58 +++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 58 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 51851d04e5..ebf4e2e7a6 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -201,7 +201,9 @@ class Keyring(object): server_name = verify_request.server_name request_id = id(verify_request) server_to_request_ids.setdefault(server_name, set()).add(request_id) - deferred.addBoth(remove_deferreds, server_name, verify_request) + verify_request.deferred.addBoth( + remove_deferreds, server_name, verify_request, + ) # Pass those keys to handle_key_deferred so that the json object # signatures can be verified diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index da2c9e44e7..2e5878f087 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -12,17 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import signedjson +from mock import Mock +from synapse.api.errors import SynapseError from synapse.crypto import keyring +from synapse.util import async from synapse.util.logcontext import LoggingContext -from tests import utils, unittest +from tests import unittest, utils from twisted.internet import defer class KeyringTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield utils.setup_test_homeserver(handlers=None) + self.http_client = Mock() + self.hs = yield utils.setup_test_homeserver( + handlers=None, + http_client=self.http_client, + ) + self.hs.config.perspectives = { + "persp_server": {"k": "v"} + } @defer.inlineCallbacks def test_wait_for_previous_lookups(self): @@ -72,3 +82,45 @@ class KeyringTestCase(unittest.TestCase): # now the second wait should complete and restore our # loggingcontext. yield wait_2_deferred + + @defer.inlineCallbacks + def test_verify_json_objects_for_server_awaits_previous_requests(self): + key1 = signedjson.key.generate_signing_key(1) + + kr = keyring.Keyring(self.hs) + json1 = {} + signedjson.sign.sign_json(json1, "server1", key1) + + self.http_client.post_json.return_value = defer.Deferred() + + # start off a first set of lookups + res_deferreds = kr.verify_json_objects_for_server( + [("server1", json1), + ("server2", {}) + ] + ) + + # the unsigned json should be rejected pretty quickly + try: + yield res_deferreds[1] + self.assertFalse("unsigned json didn't cause a failure") + except SynapseError: + pass + + self.assertFalse(res_deferreds[0].called) + + # wait a tick for it to send the request to the perspectives server + # (it first tries the datastore) + yield async.sleep(0.005) + self.http_client.post_json.assert_called_once() + + # a second request for a server with outstanding requests should + # block rather than start a second call + self.http_client.post_json.reset_mock() + self.http_client.post_json.return_value = defer.Deferred() + + kr.verify_json_objects_for_server( + [("server1", json1)], + ) + yield async.sleep(0.005) + self.http_client.post_json.assert_not_called() -- cgit 1.5.1 From 72472456d82d956d957c4a68c23554f4b43eec54 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 20 Sep 2017 01:32:42 +0100 Subject: Add some more tests for Keyring --- tests/crypto/test_keyring.py | 177 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 140 insertions(+), 37 deletions(-) (limited to 'tests') diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 2e5878f087..570312da84 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -12,39 +12,72 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import signedjson +import time + +import signedjson.key +import signedjson.sign from mock import Mock from synapse.api.errors import SynapseError from synapse.crypto import keyring -from synapse.util import async +from synapse.util import async, logcontext from synapse.util.logcontext import LoggingContext from tests import unittest, utils from twisted.internet import defer +class MockPerspectiveServer(object): + def __init__(self): + self.server_name = "mock_server" + self.key = signedjson.key.generate_signing_key(0) + + def get_verify_keys(self): + vk = signedjson.key.get_verify_key(self.key) + return { + "%s:%s" % (vk.alg, vk.version): vk, + } + + def get_signed_key(self, server_name, verify_key): + key_id = "%s:%s" % (verify_key.alg, verify_key.version) + res = { + "server_name": server_name, + "old_verify_keys": {}, + "valid_until_ts": time.time() * 1000 + 3600, + "verify_keys": { + key_id: { + "key": signedjson.key.encode_verify_key_base64(verify_key) + } + } + } + signedjson.sign.sign_json(res, self.server_name, self.key) + return res + + class KeyringTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): + self.mock_perspective_server = MockPerspectiveServer() self.http_client = Mock() self.hs = yield utils.setup_test_homeserver( handlers=None, http_client=self.http_client, ) self.hs.config.perspectives = { - "persp_server": {"k": "v"} + self.mock_perspective_server.server_name: + self.mock_perspective_server.get_verify_keys() } + def check_context(self, _, expected): + self.assertEquals( + getattr(LoggingContext.current_context(), "test_key", None), + expected + ) + @defer.inlineCallbacks def test_wait_for_previous_lookups(self): sentinel_context = LoggingContext.current_context() kr = keyring.Keyring(self.hs) - def check_context(_, expected): - self.assertEquals( - LoggingContext.current_context().test_key, expected - ) - lookup_1_deferred = defer.Deferred() lookup_2_deferred = defer.Deferred() @@ -60,7 +93,7 @@ class KeyringTestCase(unittest.TestCase): self.assertTrue(wait_1_deferred.called) # ... so we should have preserved the LoggingContext. self.assertIs(LoggingContext.current_context(), context_one) - wait_1_deferred.addBoth(check_context, "one") + wait_1_deferred.addBoth(self.check_context, "one") with LoggingContext("two") as context_two: context_two.test_key = "two" @@ -74,7 +107,7 @@ class KeyringTestCase(unittest.TestCase): self.assertFalse(wait_2_deferred.called) # ... so we should have reset the LoggingContext. self.assertIs(LoggingContext.current_context(), sentinel_context) - wait_2_deferred.addBoth(check_context, "two") + wait_2_deferred.addBoth(self.check_context, "two") # let the first lookup complete (in the sentinel context) lookup_1_deferred.callback(None) @@ -89,38 +122,108 @@ class KeyringTestCase(unittest.TestCase): kr = keyring.Keyring(self.hs) json1 = {} - signedjson.sign.sign_json(json1, "server1", key1) + signedjson.sign.sign_json(json1, "server10", key1) + + persp_resp = { + "server_keys": [ + self.mock_perspective_server.get_signed_key( + "server10", + signedjson.key.get_verify_key(key1) + ), + ] + } + persp_deferred = defer.Deferred() - self.http_client.post_json.return_value = defer.Deferred() + @defer.inlineCallbacks + def get_perspectives(**kwargs): + self.assertEquals( + LoggingContext.current_context().test_key, "11", + ) + with logcontext.PreserveLoggingContext(): + yield persp_deferred + defer.returnValue(persp_resp) + self.http_client.post_json.side_effect = get_perspectives + + with LoggingContext("11") as context_11: + context_11.test_key = "11" + + # start off a first set of lookups + res_deferreds = kr.verify_json_objects_for_server( + [("server10", json1), + ("server11", {}) + ] + ) + + # the unsigned json should be rejected pretty quickly + self.assertTrue(res_deferreds[1].called) + try: + yield res_deferreds[1] + self.assertFalse("unsigned json didn't cause a failure") + except SynapseError: + pass + + self.assertFalse(res_deferreds[0].called) + res_deferreds[0].addBoth(self.check_context, None) + + # wait a tick for it to send the request to the perspectives server + # (it first tries the datastore) + yield async.sleep(0.005) + self.http_client.post_json.assert_called_once() + + self.assertIs(LoggingContext.current_context(), context_11) + + context_12 = LoggingContext("12") + context_12.test_key = "12" + with logcontext.PreserveLoggingContext(context_12): + # a second request for a server with outstanding requests + # should block rather than start a second call + self.http_client.post_json.reset_mock() + self.http_client.post_json.return_value = defer.Deferred() + + res_deferreds_2 = kr.verify_json_objects_for_server( + [("server10", json1)], + ) + yield async.sleep(0.005) + self.http_client.post_json.assert_not_called() + res_deferreds_2[0].addBoth(self.check_context, None) + + # complete the first request + with logcontext.PreserveLoggingContext(): + persp_deferred.callback(persp_resp) + self.assertIs(LoggingContext.current_context(), context_11) + + with logcontext.PreserveLoggingContext(): + yield res_deferreds[0] + yield res_deferreds_2[0] - # start off a first set of lookups - res_deferreds = kr.verify_json_objects_for_server( - [("server1", json1), - ("server2", {}) - ] + @defer.inlineCallbacks + def test_verify_json_for_server(self): + kr = keyring.Keyring(self.hs) + + key1 = signedjson.key.generate_signing_key(1) + yield self.hs.datastore.store_server_verify_key( + "server9", "", time.time() * 1000, + signedjson.key.get_verify_key(key1), ) + json1 = {} + signedjson.sign.sign_json(json1, "server9", key1) - # the unsigned json should be rejected pretty quickly - try: - yield res_deferreds[1] - self.assertFalse("unsigned json didn't cause a failure") - except SynapseError: - pass + sentinel_context = LoggingContext.current_context() - self.assertFalse(res_deferreds[0].called) + with LoggingContext("one") as context_one: + context_one.test_key = "one" - # wait a tick for it to send the request to the perspectives server - # (it first tries the datastore) - yield async.sleep(0.005) - self.http_client.post_json.assert_called_once() + defer = kr.verify_json_for_server("server9", {}) + try: + yield defer + self.fail("should fail on unsigned json") + except SynapseError: + pass + self.assertIs(LoggingContext.current_context(), context_one) - # a second request for a server with outstanding requests should - # block rather than start a second call - self.http_client.post_json.reset_mock() - self.http_client.post_json.return_value = defer.Deferred() + defer = kr.verify_json_for_server("server9", json1) + self.assertFalse(defer.called) + self.assertIs(LoggingContext.current_context(), sentinel_context) + yield defer - kr.verify_json_objects_for_server( - [("server1", json1)], - ) - yield async.sleep(0.005) - self.http_client.post_json.assert_not_called() + self.assertIs(LoggingContext.current_context(), context_one) -- cgit 1.5.1 From f65e31d22fe9a0b07053ee15004e106ca787048b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 22 Sep 2017 20:26:47 +0100 Subject: Do an AAAA lookup on SRV record targets (#2462) Support SRV records which point at AAAA records, as well as A records. Fixes https://github.com/matrix-org/synapse/issues/2405 --- synapse/http/endpoint.py | 116 +++++++++++++++++++++++++++++++++++++++-------- tests/test_dns.py | 26 +++++++++-- 2 files changed, 118 insertions(+), 24 deletions(-) (limited to 'tests') diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index d8923c9abb..241b17f2cb 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import socket from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet import defer, reactor @@ -30,7 +31,10 @@ logger = logging.getLogger(__name__) SERVER_CACHE = {} - +# our record of an individual server which can be tried to reach a destination. +# +# "host" is actually a dotted-quad or ipv6 address string. Except when there's +# no SRV record, in which case it is the original hostname. _Server = collections.namedtuple( "_Server", "priority weight host port expires" ) @@ -219,9 +223,10 @@ class SRVClientEndpoint(object): return self.default_server else: raise ConnectError( - "Not server available for %s" % self.service_name + "No server available for %s" % self.service_name ) + # look for all servers with the same priority min_priority = self.servers[0].priority weight_indexes = list( (index, server.weight + 1) @@ -231,11 +236,22 @@ class SRVClientEndpoint(object): total_weight = sum(weight for index, weight in weight_indexes) target_weight = random.randint(0, total_weight) - for index, weight in weight_indexes: target_weight -= weight if target_weight <= 0: server = self.servers[index] + # XXX: this looks totally dubious: + # + # (a) we never reuse a server until we have been through + # all of the servers at the same priority, so if the + # weights are A: 100, B:1, we always do ABABAB instead of + # AAAA...AAAB (approximately). + # + # (b) After using all the servers at the lowest priority, + # we move onto the next priority. We should only use the + # second priority if servers at the top priority are + # unreachable. + # del self.servers[index] self.used_servers.append(server) return server @@ -280,26 +296,21 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t continue payload = answer.payload - host = str(payload.target) - srv_ttl = answer.ttl - try: - answers, _, _ = yield dns_client.lookupAddress(host) - except DNSNameError: - continue + hosts = yield _get_hosts_for_srv_record( + dns_client, str(payload.target) + ) - for answer in answers: - if answer.type == dns.A and answer.payload: - ip = answer.payload.dottedQuad() - host_ttl = min(srv_ttl, answer.ttl) + for (ip, ttl) in hosts: + host_ttl = min(answer.ttl, ttl) - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight), - expires=int(clock.time()) + host_ttl, - )) + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + host_ttl, + )) servers.sort() cache[service_name] = list(servers) @@ -317,3 +328,68 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t raise e defer.returnValue(servers) + + +@defer.inlineCallbacks +def _get_hosts_for_srv_record(dns_client, host): + """Look up each of the hosts in a SRV record + + Args: + dns_client (twisted.names.dns.IResolver): + host (basestring): host to look up + + Returns: + Deferred[list[(str, int)]]: a list of (host, ttl) pairs + + """ + ip4_servers = [] + ip6_servers = [] + + def cb(res): + # lookupAddress and lookupIP6Address return a three-tuple + # giving the answer, authority, and additional sections of the + # response. + # + # we only care about the answers. + + return res[0] + + def eb(res): + res.trap(DNSNameError) + return [] + + # no logcontexts here, so we can safely fire these off and gatherResults + d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb) + d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb) + results = yield defer.gatherResults([d1, d2], consumeErrors=True) + + for result in results: + for answer in result: + if not answer.payload: + continue + + try: + if answer.type == dns.A: + ip = answer.payload.dottedQuad() + ip4_servers.append((ip, answer.ttl)) + elif answer.type == dns.AAAA: + ip = socket.inet_ntop( + socket.AF_INET6, answer.payload.address, + ) + ip6_servers.append((ip, answer.ttl)) + else: + # the most likely candidate here is a CNAME record. + # rfc2782 says srvs may not point to aliases. + logger.warn( + "Ignoring unexpected DNS record type %s for %s", + answer.type, host, + ) + continue + except Exception as e: + logger.warn("Ignoring invalid DNS response for %s: %s", + host, e) + continue + + # keep the ipv4 results before the ipv6 results, mostly to match historical + # behaviour. + defer.returnValue(ip4_servers + ip6_servers) diff --git a/tests/test_dns.py b/tests/test_dns.py index c394c57ee7..d08b0f4333 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -24,15 +24,17 @@ from synapse.http.endpoint import resolve_service from tests.utils import MockClock +@unittest.DEBUG class DnsTestCase(unittest.TestCase): @defer.inlineCallbacks def test_resolve(self): dns_client_mock = Mock() - service_name = "test_service.examle.com" + service_name = "test_service.example.com" host_name = "example.com" ip_address = "127.0.0.1" + ip6_address = "::1" answer_srv = dns.RRHeader( type=dns.SRV, @@ -48,8 +50,22 @@ class DnsTestCase(unittest.TestCase): ) ) - dns_client_mock.lookupService.return_value = ([answer_srv], None, None) - dns_client_mock.lookupAddress.return_value = ([answer_a], None, None) + answer_aaaa = dns.RRHeader( + type=dns.AAAA, + payload=dns.Record_AAAA( + address=ip6_address, + ) + ) + + dns_client_mock.lookupService.return_value = defer.succeed( + ([answer_srv], None, None), + ) + dns_client_mock.lookupAddress.return_value = defer.succeed( + ([answer_a], None, None), + ) + dns_client_mock.lookupIPV6Address.return_value = defer.succeed( + ([answer_aaaa], None, None), + ) cache = {} @@ -59,10 +75,12 @@ class DnsTestCase(unittest.TestCase): dns_client_mock.lookupService.assert_called_once_with(service_name) dns_client_mock.lookupAddress.assert_called_once_with(host_name) + dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name) - self.assertEquals(len(servers), 1) + self.assertEquals(len(servers), 2) self.assertEquals(servers, cache[service_name]) self.assertEquals(servers[0].host, ip_address) + self.assertEquals(servers[1].host, ip6_address) @defer.inlineCallbacks def test_from_cache_expired_and_dns_fail(self): -- cgit 1.5.1