From f299c5414c2dd300103b0e11e7114123d8eb58a1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 8 Aug 2019 15:30:04 +0100 Subject: Refactor MatrixFederationAgent to retry SRV. This refactors MatrixFederationAgent to move the SRV lookup into the endpoint code, this has two benefits: 1. Its easier to retry different host/ports in the same way as HostnameEndpoint. 2. We avoid SRV lookups if we have a free connection in the pool --- .../federation/test_matrix_federation_agent.py | 63 ++++++++++++++++++++-- tests/http/federation/test_srv_resolver.py | 8 ++- 2 files changed, 66 insertions(+), 5 deletions(-) (limited to 'tests/http/federation') diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 2c568788b3..f97c8a59f6 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -41,9 +41,9 @@ from synapse.http.federation.well_known_resolver import ( from synapse.logging.context import LoggingContext from synapse.util.caches.ttlcache import TTLCache +from tests import unittest from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.server import FakeTransport, ThreadedMemoryReactorClock -from tests.unittest import TestCase from tests.utils import default_config logger = logging.getLogger(__name__) @@ -67,7 +67,8 @@ def get_connection_factory(): return test_server_connection_factory -class MatrixFederationAgentTests(TestCase): +@unittest.DEBUG +class MatrixFederationAgentTests(unittest.TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() @@ -1056,8 +1057,64 @@ class MatrixFederationAgentTests(TestCase): r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, None) + def test_srv_fallbacks(self): + """Test that other SRV results are tried if the first one fails. + """ + + self.mock_resolver.resolve_service.side_effect = lambda _: [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + self.reactor.lookups["target.com"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix._tcp.testserv" + ) + + # We should see an attempt to connect to the first server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8443) + + # Fonx the connection + client_factory.clientConnectionFailed(None, Exception("nope")) + + # There's a 300ms delay in HostnameEndpoint + self.reactor.pump((0.4,)) + + # Hasn't failed yet + self.assertNoResult(test_d) + + # We shouldnow see an attempt to connect to the second server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8444) + + # make a test server, and wire up the client + http_server = self._make_connection(client_factory, expected_sni=b"testserv") + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + -class TestCachePeriodFromHeaders(TestCase): +class TestCachePeriodFromHeaders(unittest.TestCase): def test_cache_control(self): # uppercase self.assertEqual( diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index 3b885ef64b..df034ab237 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -83,8 +83,10 @@ class SrvResolverTestCase(unittest.TestCase): service_name = b"test_service.example.com" - entry = Mock(spec_set=["expires"]) + entry = Mock(spec_set=["expires", "priority", "weight"]) entry.expires = 0 + entry.priority = 0 + entry.weight = 0 cache = {service_name: [entry]} resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) @@ -105,8 +107,10 @@ class SrvResolverTestCase(unittest.TestCase): service_name = b"test_service.example.com" - entry = Mock(spec_set=["expires"]) + entry = Mock(spec_set=["expires", "priority", "weight"]) entry.expires = 999999999 + entry.priority = 0 + entry.weight = 0 cache = {service_name: [entry]} resolver = SrvResolver( -- cgit 1.4.1