summary refs log tree commit diff
path: root/tests/http
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http')
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py211
-rw-r--r--tests/http/federation/test_srv_resolver.py8
2 files changed, 180 insertions, 39 deletions
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index b906686b49..71d7025264 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -25,23 +25,25 @@ from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOpti
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.web._newclient import ResponseNeverReceived
+from twisted.web.client import Agent
 from twisted.web.http import HTTPChannel
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IPolicyForHTTPS
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto.context_factory import ClientTLSOptionsFactory
-from synapse.http.federation.matrix_federation_agent import (
-    MatrixFederationAgent,
+from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.federation.srv_resolver import Server
+from synapse.http.federation.well_known_resolver import (
+    WellKnownResolver,
     _cache_period_from_headers,
 )
-from synapse.http.federation.srv_resolver import Server
 from synapse.logging.context import LoggingContext
 from synapse.util.caches.ttlcache import TTLCache
 
+from tests import unittest
 from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
 from tests.server import FakeTransport, ThreadedMemoryReactorClock
-from tests.unittest import TestCase
 from tests.utils import default_config
 
 logger = logging.getLogger(__name__)
@@ -65,27 +67,34 @@ def get_connection_factory():
     return test_server_connection_factory
 
 
-class MatrixFederationAgentTests(TestCase):
+class MatrixFederationAgentTests(unittest.TestCase):
     def setUp(self):
         self.reactor = ThreadedMemoryReactorClock()
 
         self.mock_resolver = Mock()
 
-        self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
-
         config_dict = default_config("test", parse=False)
         config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
-        # config_dict["trusted_key_servers"] = []
 
         self._config = config = HomeServerConfig()
         config.parse_config_dict(config_dict, "", "")
 
+        self.tls_factory = ClientTLSOptionsFactory(config)
+
+        self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+        self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+        self.well_known_resolver = WellKnownResolver(
+            self.reactor,
+            Agent(self.reactor, contextFactory=self.tls_factory),
+            well_known_cache=self.well_known_cache,
+            had_well_known_cache=self.had_well_known_cache,
+        )
+
         self.agent = MatrixFederationAgent(
             reactor=self.reactor,
-            tls_client_options_factory=ClientTLSOptionsFactory(config),
-            _well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
+            tls_client_options_factory=self.tls_factory,
             _srv_resolver=self.mock_resolver,
-            _well_known_cache=self.well_known_cache,
+            _well_known_resolver=self.well_known_resolver,
         )
 
     def _make_connection(self, client_factory, expected_sni):
@@ -542,7 +551,7 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
 
         # check the cache expires
-        self.reactor.pump((25 * 3600,))
+        self.reactor.pump((48 * 3600,))
         self.well_known_cache.expire()
         self.assertNotIn(b"testserv", self.well_known_cache)
 
@@ -630,7 +639,7 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
 
         # check the cache expires
-        self.reactor.pump((25 * 3600,))
+        self.reactor.pump((48 * 3600,))
         self.well_known_cache.expire()
         self.assertNotIn(b"testserv", self.well_known_cache)
 
@@ -691,18 +700,27 @@ class MatrixFederationAgentTests(TestCase):
         not signed by a CA
         """
 
-        # we use the same test server as the other tests, but use an agent
-        # with _well_known_tls_policy left to the default, which will not
-        # trust it (since the presented cert is signed by a test CA)
+        # we use the same test server as the other tests, but use an agent with
+        # the config left to the default, which will not trust it (since the
+        # presented cert is signed by a test CA)
 
         self.mock_resolver.resolve_service.side_effect = lambda _: []
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
+        config = default_config("test", parse=True)
+
+        # Build a new agent and WellKnownResolver with a different tls factory
+        tls_factory = ClientTLSOptionsFactory(config)
         agent = MatrixFederationAgent(
             reactor=self.reactor,
-            tls_client_options_factory=ClientTLSOptionsFactory(self._config),
+            tls_client_options_factory=tls_factory,
             _srv_resolver=self.mock_resolver,
-            _well_known_cache=self.well_known_cache,
+            _well_known_resolver=WellKnownResolver(
+                self.reactor,
+                Agent(self.reactor, contextFactory=tls_factory),
+                well_known_cache=self.well_known_cache,
+                had_well_known_cache=self.had_well_known_cache,
+            ),
         )
 
         test_d = agent.request(b"GET", b"matrix://testserv/foo/bar")
@@ -928,20 +946,10 @@ class MatrixFederationAgentTests(TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
-    @defer.inlineCallbacks
-    def do_get_well_known(self, serv):
-        try:
-            result = yield self.agent._get_well_known(serv)
-            logger.info("Result from well-known fetch: %s", result)
-        except Exception as e:
-            logger.warning("Error fetching well-known: %s", e)
-            raise
-        return result
-
     def test_well_known_cache(self):
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
-        fetch_d = self.do_get_well_known(b"testserv")
+        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
 
         # there should be an attempt to connect on port 443 for the .well-known
         clients = self.reactor.tcpClients
@@ -953,26 +961,26 @@ class MatrixFederationAgentTests(TestCase):
         well_known_server = self._handle_well_known_connection(
             client_factory,
             expected_sni=b"testserv",
-            response_headers={b"Cache-Control": b"max-age=10"},
+            response_headers={b"Cache-Control": b"max-age=1000"},
             content=b'{ "m.server": "target-server" }',
         )
 
         r = self.successResultOf(fetch_d)
-        self.assertEqual(r, b"target-server")
+        self.assertEqual(r.delegated_server, b"target-server")
 
         # close the tcp connection
         well_known_server.loseConnection()
 
         # repeat the request: it should hit the cache
-        fetch_d = self.do_get_well_known(b"testserv")
+        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
         r = self.successResultOf(fetch_d)
-        self.assertEqual(r, b"target-server")
+        self.assertEqual(r.delegated_server, b"target-server")
 
         # expire the cache
-        self.reactor.pump((10.0,))
+        self.reactor.pump((1000.0,))
 
         # now it should connect again
-        fetch_d = self.do_get_well_known(b"testserv")
+        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
 
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -986,10 +994,139 @@ class MatrixFederationAgentTests(TestCase):
         )
 
         r = self.successResultOf(fetch_d)
-        self.assertEqual(r, b"other-server")
+        self.assertEqual(r.delegated_server, b"other-server")
+
+    def test_well_known_cache_with_temp_failure(self):
+        """Test that we refetch well-known before the cache expires, and that
+        it ignores transient errors.
+        """
+
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+
+        # there should be an attempt to connect on port 443 for the .well-known
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 443)
+
+        well_known_server = self._handle_well_known_connection(
+            client_factory,
+            expected_sni=b"testserv",
+            response_headers={b"Cache-Control": b"max-age=1000"},
+            content=b'{ "m.server": "target-server" }',
+        )
+
+        r = self.successResultOf(fetch_d)
+        self.assertEqual(r.delegated_server, b"target-server")
+
+        # close the tcp connection
+        well_known_server.loseConnection()
+
+        # Get close to the cache expiry, this will cause the resolver to do
+        # another lookup.
+        self.reactor.pump((900.0,))
+
+        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+
+        # The resolver may retry a few times, so fonx all requests that come along
+        attempts = 0
+        while self.reactor.tcpClients:
+            clients = self.reactor.tcpClients
+            (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+
+            attempts += 1
+
+            # fonx the connection attempt, this will be treated as a temporary
+            # failure.
+            client_factory.clientConnectionFailed(None, Exception("nope"))
+
+            # There's a few sleeps involved, so we have to pump the reactor a
+            # bit.
+            self.reactor.pump((1.0, 1.0))
+
+        # We expect to see more than one attempt as there was previously a valid
+        # well known.
+        self.assertGreater(attempts, 1)
+
+        # Resolver should return cached value, despite the lookup failing.
+        r = self.successResultOf(fetch_d)
+        self.assertEqual(r.delegated_server, b"target-server")
+
+        # Expire both caches and repeat the request
+        self.reactor.pump((10000.0,))
+
+        # Repated the request, this time it should fail if the lookup fails.
+        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+
+        clients = self.reactor.tcpClients
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        client_factory.clientConnectionFailed(None, Exception("nope"))
+        self.reactor.pump((0.4,))
+
+        r = self.successResultOf(fetch_d)
+        self.assertEqual(r.delegated_server, None)
+
+    def test_srv_fallbacks(self):
+        """Test that other SRV results are tried if the first one fails.
+        """
+
+        self.mock_resolver.resolve_service.side_effect = lambda _: [
+            Server(host=b"target.com", port=8443),
+            Server(host=b"target.com", port=8444),
+        ]
+        self.reactor.lookups["target.com"] = "1.2.3.4"
+
+        test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        self.mock_resolver.resolve_service.assert_called_once_with(
+            b"_matrix._tcp.testserv"
+        )
+
+        # We should see an attempt to connect to the first server
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8443)
+
+        # Fonx the connection
+        client_factory.clientConnectionFailed(None, Exception("nope"))
+
+        # There's a 300ms delay in HostnameEndpoint
+        self.reactor.pump((0.4,))
+
+        # Hasn't failed yet
+        self.assertNoResult(test_d)
+
+        # We shouldnow see an attempt to connect to the second server
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8444)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(client_factory, expected_sni=b"testserv")
+
+        self.assertEqual(len(http_server.requests), 1)
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/foo/bar")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
+
+        # finish the request
+        request.finish()
+        self.reactor.pump((0.1,))
+        self.successResultOf(test_d)
 
 
-class TestCachePeriodFromHeaders(TestCase):
+class TestCachePeriodFromHeaders(unittest.TestCase):
     def test_cache_control(self):
         # uppercase
         self.assertEqual(
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 3b885ef64b..df034ab237 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -83,8 +83,10 @@ class SrvResolverTestCase(unittest.TestCase):
 
         service_name = b"test_service.example.com"
 
-        entry = Mock(spec_set=["expires"])
+        entry = Mock(spec_set=["expires", "priority", "weight"])
         entry.expires = 0
+        entry.priority = 0
+        entry.weight = 0
 
         cache = {service_name: [entry]}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@@ -105,8 +107,10 @@ class SrvResolverTestCase(unittest.TestCase):
 
         service_name = b"test_service.example.com"
 
-        entry = Mock(spec_set=["expires"])
+        entry = Mock(spec_set=["expires", "priority", "weight"])
         entry.expires = 999999999
+        entry.priority = 0
+        entry.weight = 0
 
         cache = {service_name: [entry]}
         resolver = SrvResolver(