diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 11ea8ef10c..fe459ea6e3 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -24,11 +24,16 @@ from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOpti
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.web.http import HTTPChannel
+from twisted.web.http_headers import Headers
from twisted.web.iweb import IPolicyForHTTPS
from synapse.crypto.context_factory import ClientTLSOptionsFactory
-from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.federation.matrix_federation_agent import (
+ MatrixFederationAgent,
+ _cache_period_from_headers,
+)
from synapse.http.federation.srv_resolver import Server
+from synapse.util.caches.ttlcache import TTLCache
from synapse.util.logcontext import LoggingContext
from tests.http import ServerTLSContext
@@ -44,11 +49,14 @@ class MatrixFederationAgentTests(TestCase):
self.mock_resolver = Mock()
+ self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(None),
_well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
_srv_resolver=self.mock_resolver,
+ _well_known_cache=self.well_known_cache,
)
def _make_connection(self, client_factory, expected_sni):
@@ -115,7 +123,9 @@ class MatrixFederationAgentTests(TestCase):
finally:
_check_logcontext(context)
- def _handle_well_known_connection(self, client_factory, expected_sni, target_server):
+ def _handle_well_known_connection(
+ self, client_factory, expected_sni, target_server, response_headers={},
+ ):
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
request is for a .well-known, and send the response.
@@ -124,6 +134,8 @@ class MatrixFederationAgentTests(TestCase):
expected_sni (bytes): SNI that we expect the outgoing connection to send
target_server (bytes): target server that we should redirect to in the
.well-known response.
+ Returns:
+ HTTPChannel: server impl
"""
# make the connection for .well-known
well_known_server = self._make_connection(
@@ -133,9 +145,10 @@ class MatrixFederationAgentTests(TestCase):
# check the .well-known request and send a response
self.assertEqual(len(well_known_server.requests), 1)
request = well_known_server.requests[0]
- self._send_well_known_response(request, target_server)
+ self._send_well_known_response(request, target_server, headers=response_headers)
+ return well_known_server
- def _send_well_known_response(self, request, target_server):
+ def _send_well_known_response(self, request, target_server, headers={}):
"""Check that an incoming request looks like a valid .well-known request, and
send back the response.
"""
@@ -146,6 +159,8 @@ class MatrixFederationAgentTests(TestCase):
[b'testserv'],
)
# send back a response
+ for k, v in headers.items():
+ request.setHeader(k, v)
request.write(b'{ "m.server": "%s" }' % (target_server,))
request.finish()
@@ -448,6 +463,13 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
+ self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
+
+ # check the cache expires
+ self.reactor.pump((25 * 3600,))
+ self.well_known_cache.expire()
+ self.assertNotIn(b"testserv", self.well_known_cache)
+
def test_get_hostname_srv(self):
"""
Test the behaviour when there is a single SRV record
@@ -661,6 +683,126 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
+ @defer.inlineCallbacks
+ def do_get_well_known(self, serv):
+ try:
+ result = yield self.agent._get_well_known(serv)
+ logger.info("Result from well-known fetch: %s", result)
+ except Exception as e:
+ logger.warning("Error fetching well-known: %s", e)
+ raise
+ defer.returnValue(result)
+
+ def test_well_known_cache(self):
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ fetch_d = self.do_get_well_known(b'testserv')
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ well_known_server = self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ response_headers={b'Cache-Control': b'max-age=10'},
+ target_server=b"target-server",
+ )
+
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r, b'target-server')
+
+ # close the tcp connection
+ well_known_server.loseConnection()
+
+ # repeat the request: it should hit the cache
+ fetch_d = self.do_get_well_known(b'testserv')
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r, b'target-server')
+
+ # expire the cache
+ self.reactor.pump((10.0,))
+
+ # now it should connect again
+ fetch_d = self.do_get_well_known(b'testserv')
+
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ target_server=b"other-server",
+ )
+
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r, b'other-server')
+
+
+class TestCachePeriodFromHeaders(TestCase):
+ def test_cache_control(self):
+ # uppercase
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}),
+ ), 100,
+ )
+
+ # missing value
+ self.assertIsNone(_cache_period_from_headers(
+ Headers({b'Cache-Control': [b'max-age=, bar']}),
+ ))
+
+ # hackernews: bogus due to semicolon
+ self.assertIsNone(_cache_period_from_headers(
+ Headers({b'Cache-Control': [b'private; max-age=0']}),
+ ))
+
+ # github
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}),
+ ), 0,
+ )
+
+ # google
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'cache-control': [b'private, max-age=0']}),
+ ), 0,
+ )
+
+ def test_expires(self):
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
+ time_now=lambda: 1548833700
+ ), 33,
+ )
+
+ # cache-control overrides expires
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({
+ b'cache-control': [b'max-age=10'],
+ b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']
+ }),
+ time_now=lambda: 1548833700
+ ), 10,
+ )
+
+ # invalid expires means immediate expiry
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Expires': [b'0']}),
+ ), 0,
+ )
+
def _check_logcontext(context):
current = LoggingContext.current_context()
|