diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 1435baede2..71d7025264 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -41,9 +41,9 @@ from synapse.http.federation.well_known_resolver import (
from synapse.logging.context import LoggingContext
from synapse.util.caches.ttlcache import TTLCache
+from tests import unittest
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.server import FakeTransport, ThreadedMemoryReactorClock
-from tests.unittest import TestCase
from tests.utils import default_config
logger = logging.getLogger(__name__)
@@ -67,14 +67,12 @@ def get_connection_factory():
return test_server_connection_factory
-class MatrixFederationAgentTests(TestCase):
+class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
self.mock_resolver = Mock()
- self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
-
config_dict = default_config("test", parse=False)
config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
@@ -82,11 +80,21 @@ class MatrixFederationAgentTests(TestCase):
config.parse_config_dict(config_dict, "", "")
self.tls_factory = ClientTLSOptionsFactory(config)
+
+ self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+ self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+ self.well_known_resolver = WellKnownResolver(
+ self.reactor,
+ Agent(self.reactor, contextFactory=self.tls_factory),
+ well_known_cache=self.well_known_cache,
+ had_well_known_cache=self.had_well_known_cache,
+ )
+
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=self.tls_factory,
_srv_resolver=self.mock_resolver,
- _well_known_cache=self.well_known_cache,
+ _well_known_resolver=self.well_known_resolver,
)
def _make_connection(self, client_factory, expected_sni):
@@ -543,7 +551,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
# check the cache expires
- self.reactor.pump((25 * 3600,))
+ self.reactor.pump((48 * 3600,))
self.well_known_cache.expire()
self.assertNotIn(b"testserv", self.well_known_cache)
@@ -631,7 +639,7 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
# check the cache expires
- self.reactor.pump((25 * 3600,))
+ self.reactor.pump((48 * 3600,))
self.well_known_cache.expire()
self.assertNotIn(b"testserv", self.well_known_cache)
@@ -701,11 +709,18 @@ class MatrixFederationAgentTests(TestCase):
config = default_config("test", parse=True)
+ # Build a new agent and WellKnownResolver with a different tls factory
+ tls_factory = ClientTLSOptionsFactory(config)
agent = MatrixFederationAgent(
reactor=self.reactor,
- tls_client_options_factory=ClientTLSOptionsFactory(config),
+ tls_client_options_factory=tls_factory,
_srv_resolver=self.mock_resolver,
- _well_known_cache=self.well_known_cache,
+ _well_known_resolver=WellKnownResolver(
+ self.reactor,
+ Agent(self.reactor, contextFactory=tls_factory),
+ well_known_cache=self.well_known_cache,
+ had_well_known_cache=self.had_well_known_cache,
+ ),
)
test_d = agent.request(b"GET", b"matrix://testserv/foo/bar")
@@ -932,15 +947,9 @@ class MatrixFederationAgentTests(TestCase):
self.successResultOf(test_d)
def test_well_known_cache(self):
- well_known_resolver = WellKnownResolver(
- self.reactor,
- Agent(self.reactor, contextFactory=self.tls_factory),
- well_known_cache=self.well_known_cache,
- )
-
self.reactor.lookups["testserv"] = "1.2.3.4"
- fetch_d = well_known_resolver.get_well_known(b"testserv")
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -963,7 +972,7 @@ class MatrixFederationAgentTests(TestCase):
well_known_server.loseConnection()
# repeat the request: it should hit the cache
- fetch_d = well_known_resolver.get_well_known(b"testserv")
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"target-server")
@@ -971,7 +980,7 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((1000.0,))
# now it should connect again
- fetch_d = well_known_resolver.get_well_known(b"testserv")
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -987,8 +996,137 @@ class MatrixFederationAgentTests(TestCase):
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"other-server")
+ def test_well_known_cache_with_temp_failure(self):
+ """Test that we refetch well-known before the cache expires, and that
+ it ignores transient errors.
+ """
+
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 443)
+
+ well_known_server = self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ response_headers={b"Cache-Control": b"max-age=1000"},
+ content=b'{ "m.server": "target-server" }',
+ )
+
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r.delegated_server, b"target-server")
+
+ # close the tcp connection
+ well_known_server.loseConnection()
+
+ # Get close to the cache expiry, this will cause the resolver to do
+ # another lookup.
+ self.reactor.pump((900.0,))
+
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+
+ # The resolver may retry a few times, so fonx all requests that come along
+ attempts = 0
+ while self.reactor.tcpClients:
+ clients = self.reactor.tcpClients
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+
+ attempts += 1
+
+ # fonx the connection attempt, this will be treated as a temporary
+ # failure.
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # There are a few sleeps involved, so we have to pump the reactor a
+ # bit.
+ self.reactor.pump((1.0, 1.0))
+
+ # We expect to see more than one attempt as there was previously a valid
+ # well known.
+ self.assertGreater(attempts, 1)
+
+ # Resolver should return cached value, despite the lookup failing.
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r.delegated_server, b"target-server")
+
+ # Expire both caches and repeat the request
+ self.reactor.pump((10000.0,))
+
+ # Repeat the request, this time it should fail if the lookup fails.
+ fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+
+ clients = self.reactor.tcpClients
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+ self.reactor.pump((0.4,))
+
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r.delegated_server, None)
+
+ def test_srv_fallbacks(self):
+ """Test that other SRV results are tried if the first one fails.
+ """
+
+ self.mock_resolver.resolve_service.side_effect = lambda _: [
+ Server(host=b"target.com", port=8443),
+ Server(host=b"target.com", port=8444),
+ ]
+ self.reactor.lookups["target.com"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://testserv/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_called_once_with(
+ b"_matrix._tcp.testserv"
+ )
+
+ # We should see an attempt to connect to the first server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8443)
+
+ # Fonx the connection
+ client_factory.clientConnectionFailed(None, Exception("nope"))
+
+ # There's a 300ms delay in HostnameEndpoint
+ self.reactor.pump((0.4,))
+
+ # Hasn't failed yet
+ self.assertNoResult(test_d)
+
+ # We should now see an attempt to connect to the second server
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8444)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(client_factory, expected_sni=b"testserv")
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b"GET")
+ self.assertEqual(request.path, b"/foo/bar")
+ self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
-class TestCachePeriodFromHeaders(TestCase):
+class TestCachePeriodFromHeaders(unittest.TestCase):
def test_cache_control(self):
# uppercase
self.assertEqual(
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 3b885ef64b..df034ab237 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -83,8 +83,10 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
- entry = Mock(spec_set=["expires"])
+ entry = Mock(spec_set=["expires", "priority", "weight"])
entry.expires = 0
+ entry.priority = 0
+ entry.weight = 0
cache = {service_name: [entry]}
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
@@ -105,8 +107,10 @@ class SrvResolverTestCase(unittest.TestCase):
service_name = b"test_service.example.com"
- entry = Mock(spec_set=["expires"])
+ entry = Mock(spec_set=["expires", "priority", "weight"])
entry.expires = 999999999
+ entry.priority = 0
+ entry.weight = 0
cache = {service_name: [entry]}
resolver = SrvResolver(
|