diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/http/federation/test_matrix_federation_agent.py | 150 | ||||
-rw-r--r-- | tests/server.py | 18 | ||||
-rw-r--r-- | tests/util/caches/test_ttlcache.py | 83 |
3 files changed, 243 insertions, 8 deletions
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 11ea8ef10c..fe459ea6e3 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -24,11 +24,16 @@ from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOpti from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.web.http import HTTPChannel +from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS from synapse.crypto.context_factory import ClientTLSOptionsFactory -from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent +from synapse.http.federation.matrix_federation_agent import ( + MatrixFederationAgent, + _cache_period_from_headers, +) from synapse.http.federation.srv_resolver import Server +from synapse.util.caches.ttlcache import TTLCache from synapse.util.logcontext import LoggingContext from tests.http import ServerTLSContext @@ -44,11 +49,14 @@ class MatrixFederationAgentTests(TestCase): self.mock_resolver = Mock() + self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) + self.agent = MatrixFederationAgent( reactor=self.reactor, tls_client_options_factory=ClientTLSOptionsFactory(None), _well_known_tls_policy=TrustingTLSPolicyForHTTPS(), _srv_resolver=self.mock_resolver, + _well_known_cache=self.well_known_cache, ) def _make_connection(self, client_factory, expected_sni): @@ -115,7 +123,9 @@ class MatrixFederationAgentTests(TestCase): finally: _check_logcontext(context) - def _handle_well_known_connection(self, client_factory, expected_sni, target_server): + def _handle_well_known_connection( + self, client_factory, expected_sni, target_server, response_headers={}, + ): """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. @@ -124,6 +134,8 @@ class MatrixFederationAgentTests(TestCase): expected_sni (bytes): SNI that we expect the outgoing connection to send target_server (bytes): target server that we should redirect to in the .well-known response. + Returns: + HTTPChannel: server impl """ # make the connection for .well-known well_known_server = self._make_connection( @@ -133,9 +145,10 @@ class MatrixFederationAgentTests(TestCase): # check the .well-known request and send a response self.assertEqual(len(well_known_server.requests), 1) request = well_known_server.requests[0] - self._send_well_known_response(request, target_server) + self._send_well_known_response(request, target_server, headers=response_headers) + return well_known_server - def _send_well_known_response(self, request, target_server): + def _send_well_known_response(self, request, target_server, headers={}): """Check that an incoming request looks like a valid .well-known request, and send back the response. """ @@ -146,6 +159,8 @@ class MatrixFederationAgentTests(TestCase): [b'testserv'], ) # send back a response + for k, v in headers.items(): + request.setHeader(k, v) request.write(b'{ "m.server": "%s" }' % (target_server,)) request.finish() @@ -448,6 +463,13 @@ class MatrixFederationAgentTests(TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) + self.assertEqual(self.well_known_cache[b"testserv"], b"target-server") + + # check the cache expires + self.reactor.pump((25 * 3600,)) + self.well_known_cache.expire() + self.assertNotIn(b"testserv", self.well_known_cache) + def test_get_hostname_srv(self): """ Test the behaviour when there is a single SRV record @@ -661,6 +683,126 @@ class MatrixFederationAgentTests(TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) + @defer.inlineCallbacks + def do_get_well_known(self, serv): + try: + result = yield self.agent._get_well_known(serv) + logger.info("Result from well-known fetch: %s", result) + except Exception as e: + logger.warning("Error fetching well-known: %s", e) + raise + defer.returnValue(result) + + def test_well_known_cache(self): + self.reactor.lookups["testserv"] = "1.2.3.4" + + fetch_d = self.do_get_well_known(b'testserv') + + # there should be an attempt to connect on port 443 for the .well-known + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 443) + + well_known_server = self._handle_well_known_connection( + client_factory, + expected_sni=b"testserv", + response_headers={b'Cache-Control': b'max-age=10'}, + target_server=b"target-server", + ) + + r = self.successResultOf(fetch_d) + self.assertEqual(r, b'target-server') + + # close the tcp connection + well_known_server.loseConnection() + + # repeat the request: it should hit the cache + fetch_d = self.do_get_well_known(b'testserv') + r = self.successResultOf(fetch_d) + self.assertEqual(r, b'target-server') + + # expire the cache + self.reactor.pump((10.0,)) + + # now it should connect again + fetch_d = self.do_get_well_known(b'testserv') + + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 443) + + self._handle_well_known_connection( + client_factory, + expected_sni=b"testserv", + target_server=b"other-server", + ) + + r = self.successResultOf(fetch_d) + self.assertEqual(r, b'other-server') + + +class TestCachePeriodFromHeaders(TestCase): + def test_cache_control(self): + # uppercase + self.assertEqual( + _cache_period_from_headers( + Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}), + ), 100, + ) + + # missing value + self.assertIsNone(_cache_period_from_headers( + Headers({b'Cache-Control': [b'max-age=, bar']}), + )) + + # hackernews: bogus due to semicolon + self.assertIsNone(_cache_period_from_headers( + Headers({b'Cache-Control': [b'private; max-age=0']}), + )) + + # github + self.assertEqual( + _cache_period_from_headers( + Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}), + ), 0, + ) + + # google + self.assertEqual( + _cache_period_from_headers( + Headers({b'cache-control': [b'private, max-age=0']}), + ), 0, + ) + + def test_expires(self): + self.assertEqual( + _cache_period_from_headers( + Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}), + time_now=lambda: 1548833700 + ), 33, + ) + + # cache-control overrides expires + self.assertEqual( + _cache_period_from_headers( + Headers({ + b'cache-control': [b'max-age=10'], + b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'] + }), + time_now=lambda: 1548833700 + ), 10, + ) + + # invalid expires means immediate expiry + self.assertEqual( + _cache_period_from_headers( + Headers({b'Expires': [b'0']}), + ), 0, + ) + def _check_logcontext(context): current = LoggingContext.current_context() diff --git a/tests/server.py b/tests/server.py index 3d7ae9875c..fc1e76d146 100644 --- a/tests/server.py +++ b/tests/server.py @@ -360,6 +360,7 @@ class FakeTransport(object): """ disconnecting = False + disconnected = False buffer = attr.ib(default=b'') producer = attr.ib(default=None) @@ -370,14 +371,16 @@ class FakeTransport(object): return None def loseConnection(self, reason=None): - logger.info("FakeTransport: loseConnection(%s)", reason) if not self.disconnecting: + logger.info("FakeTransport: loseConnection(%s)", reason) self.disconnecting = True if self._protocol: self._protocol.connectionLost(reason) + self.disconnected = True def abortConnection(self): - self.disconnecting = True + logger.info("FakeTransport: abortConnection()") + self.loseConnection() def pauseProducing(self): if not self.producer: @@ -416,9 +419,16 @@ class FakeTransport(object): # TLSMemoryBIOProtocol return + if self.disconnected: + return + logger.info("%s->%s: %s", self._protocol, self.other, self.buffer) + if getattr(self.other, "transport") is not None: - self.other.dataReceived(self.buffer) - self.buffer = b"" + try: + self.other.dataReceived(self.buffer) + self.buffer = b"" + except Exception as e: + logger.warning("Exception writing to protocol: %s", e) return self._reactor.callLater(0.0, _write) diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py new file mode 100644 index 0000000000..03b3c15db6 --- /dev/null +++ b/tests/util/caches/test_ttlcache.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from synapse.util.caches.ttlcache import TTLCache + +from tests import unittest + + +class CacheTestCase(unittest.TestCase): + def setUp(self): + self.mock_timer = Mock(side_effect=lambda: 100.0) + self.cache = TTLCache("test_cache", self.mock_timer) + + def test_get(self): + """simple set/get tests""" + self.cache.set('one', '1', 10) + self.cache.set('two', '2', 20) + self.cache.set('three', '3', 30) + + self.assertEqual(len(self.cache), 3) + + self.assertTrue('one' in self.cache) + self.assertEqual(self.cache.get('one'), '1') + self.assertEqual(self.cache['one'], '1') + self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110)) + self.assertEqual(self.cache._metrics.hits, 3) + self.assertEqual(self.cache._metrics.misses, 0) + + self.cache.set('two', '2.5', 20) + self.assertEqual(self.cache['two'], '2.5') + self.assertEqual(self.cache._metrics.hits, 4) + + # non-existent-item tests + self.assertEqual(self.cache.get('four', '4'), '4') + self.assertIs(self.cache.get('four', None), None) + + with self.assertRaises(KeyError): + self.cache['four'] + + with self.assertRaises(KeyError): + self.cache.get('four') + + with self.assertRaises(KeyError): + self.cache.get_with_expiry('four') + + self.assertEqual(self.cache._metrics.hits, 4) + self.assertEqual(self.cache._metrics.misses, 5) + + def test_expiry(self): + self.cache.set('one', '1', 10) + self.cache.set('two', '2', 20) + self.cache.set('three', '3', 30) + + self.assertEqual(len(self.cache), 3) + self.assertEqual(self.cache['one'], '1') + self.assertEqual(self.cache['two'], '2') + + # enough for the first entry to expire, but not the rest + self.mock_timer.side_effect = lambda: 110.0 + + self.assertEqual(len(self.cache), 2) + self.assertFalse('one' in self.cache) + self.assertEqual(self.cache['two'], '2') + self.assertEqual(self.cache['three'], '3') + + self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120)) + + self.assertEqual(self.cache._metrics.hits, 5) + self.assertEqual(self.cache._metrics.misses, 0) |