summary refs log tree commit diff
path: root/tests/http/federation
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http/federation')
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py150
1 files changed, 146 insertions, 4 deletions
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 11ea8ef10c..fe459ea6e3 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -24,11 +24,16 @@ from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOpti
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.web.http import HTTPChannel
+from twisted.web.http_headers import Headers
 from twisted.web.iweb import IPolicyForHTTPS
 
 from synapse.crypto.context_factory import ClientTLSOptionsFactory
-from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.federation.matrix_federation_agent import (
+    MatrixFederationAgent,
+    _cache_period_from_headers,
+)
 from synapse.http.federation.srv_resolver import Server
+from synapse.util.caches.ttlcache import TTLCache
 from synapse.util.logcontext import LoggingContext
 
 from tests.http import ServerTLSContext
@@ -44,11 +49,14 @@ class MatrixFederationAgentTests(TestCase):
 
         self.mock_resolver = Mock()
 
+        self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+
         self.agent = MatrixFederationAgent(
             reactor=self.reactor,
             tls_client_options_factory=ClientTLSOptionsFactory(None),
             _well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
             _srv_resolver=self.mock_resolver,
+            _well_known_cache=self.well_known_cache,
         )
 
     def _make_connection(self, client_factory, expected_sni):
@@ -115,7 +123,9 @@ class MatrixFederationAgentTests(TestCase):
             finally:
                 _check_logcontext(context)
 
-    def _handle_well_known_connection(self, client_factory, expected_sni, target_server):
+    def _handle_well_known_connection(
+        self, client_factory, expected_sni, target_server, response_headers={},
+    ):
         """Handle an outgoing HTTPs connection: wire it up to a server, check that the
         request is for a .well-known, and send the response.
 
@@ -124,6 +134,8 @@ class MatrixFederationAgentTests(TestCase):
             expected_sni (bytes): SNI that we expect the outgoing connection to send
             target_server (bytes): target server that we should redirect to in the
                 .well-known response.
+        Returns:
+            HTTPChannel: server impl
         """
         # make the connection for .well-known
         well_known_server = self._make_connection(
@@ -133,9 +145,10 @@ class MatrixFederationAgentTests(TestCase):
         # check the .well-known request and send a response
         self.assertEqual(len(well_known_server.requests), 1)
         request = well_known_server.requests[0]
-        self._send_well_known_response(request, target_server)
+        self._send_well_known_response(request, target_server, headers=response_headers)
+        return well_known_server
 
-    def _send_well_known_response(self, request, target_server):
+    def _send_well_known_response(self, request, target_server, headers={}):
         """Check that an incoming request looks like a valid .well-known request, and
         send back the response.
         """
@@ -146,6 +159,8 @@ class MatrixFederationAgentTests(TestCase):
             [b'testserv'],
         )
         # send back a response
+        for k, v in headers.items():
+            request.setHeader(k, v)
         request.write(b'{ "m.server": "%s" }' % (target_server,))
         request.finish()
 
@@ -448,6 +463,13 @@ class MatrixFederationAgentTests(TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
+        self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
+
+        # check the cache expires
+        self.reactor.pump((25 * 3600,))
+        self.well_known_cache.expire()
+        self.assertNotIn(b"testserv", self.well_known_cache)
+
     def test_get_hostname_srv(self):
         """
         Test the behaviour when there is a single SRV record
@@ -661,6 +683,126 @@ class MatrixFederationAgentTests(TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
+    @defer.inlineCallbacks
+    def do_get_well_known(self, serv):
+        try:
+            result = yield self.agent._get_well_known(serv)
+            logger.info("Result from well-known fetch: %s", result)
+        except Exception as e:
+            logger.warning("Error fetching well-known: %s", e)
+            raise
+        defer.returnValue(result)
+
+    def test_well_known_cache(self):
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+        fetch_d = self.do_get_well_known(b'testserv')
+
+        # there should be an attempt to connect on port 443 for the .well-known
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(port, 443)
+
+        well_known_server = self._handle_well_known_connection(
+            client_factory,
+            expected_sni=b"testserv",
+            response_headers={b'Cache-Control': b'max-age=10'},
+            target_server=b"target-server",
+        )
+
+        r = self.successResultOf(fetch_d)
+        self.assertEqual(r, b'target-server')
+
+        # close the tcp connection
+        well_known_server.loseConnection()
+
+        # repeat the request: it should hit the cache
+        fetch_d = self.do_get_well_known(b'testserv')
+        r = self.successResultOf(fetch_d)
+        self.assertEqual(r, b'target-server')
+
+        # expire the cache
+        self.reactor.pump((10.0,))
+
+        # now it should connect again
+        fetch_d = self.do_get_well_known(b'testserv')
+
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(port, 443)
+
+        self._handle_well_known_connection(
+            client_factory,
+            expected_sni=b"testserv",
+            target_server=b"other-server",
+        )
+
+        r = self.successResultOf(fetch_d)
+        self.assertEqual(r, b'other-server')
+
+
+class TestCachePeriodFromHeaders(TestCase):
+    def test_cache_control(self):
+        # uppercase
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}),
+            ), 100,
+        )
+
+        # missing value
+        self.assertIsNone(_cache_period_from_headers(
+            Headers({b'Cache-Control': [b'max-age=, bar']}),
+        ))
+
+        # hackernews: bogus due to semicolon
+        self.assertIsNone(_cache_period_from_headers(
+            Headers({b'Cache-Control': [b'private; max-age=0']}),
+        ))
+
+        # github
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}),
+            ), 0,
+        )
+
+        # google
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({b'cache-control': [b'private, max-age=0']}),
+            ), 0,
+        )
+
+    def test_expires(self):
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
+                time_now=lambda: 1548833700
+            ), 33,
+        )
+
+        # cache-control overrides expires
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({
+                    b'cache-control': [b'max-age=10'],
+                    b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']
+                }),
+                time_now=lambda: 1548833700
+            ), 10,
+        )
+
+        # invalid expires means immediate expiry
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({b'Expires': [b'0']}),
+            ), 0,
+        )
+
 
 def _check_logcontext(context):
     current = LoggingContext.current_context()