diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 11ea8ef10c..fe459ea6e3 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -24,11 +24,16 @@ from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOpti
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.web.http import HTTPChannel
+from twisted.web.http_headers import Headers
from twisted.web.iweb import IPolicyForHTTPS
from synapse.crypto.context_factory import ClientTLSOptionsFactory
-from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.federation.matrix_federation_agent import (
+ MatrixFederationAgent,
+ _cache_period_from_headers,
+)
from synapse.http.federation.srv_resolver import Server
+from synapse.util.caches.ttlcache import TTLCache
from synapse.util.logcontext import LoggingContext
from tests.http import ServerTLSContext
@@ -44,11 +49,14 @@ class MatrixFederationAgentTests(TestCase):
self.mock_resolver = Mock()
+ self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+
self.agent = MatrixFederationAgent(
reactor=self.reactor,
tls_client_options_factory=ClientTLSOptionsFactory(None),
_well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
_srv_resolver=self.mock_resolver,
+ _well_known_cache=self.well_known_cache,
)
def _make_connection(self, client_factory, expected_sni):
@@ -115,7 +123,9 @@ class MatrixFederationAgentTests(TestCase):
finally:
_check_logcontext(context)
- def _handle_well_known_connection(self, client_factory, expected_sni, target_server):
+ def _handle_well_known_connection(
+ self, client_factory, expected_sni, target_server, response_headers={},
+ ):
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
request is for a .well-known, and send the response.
@@ -124,6 +134,8 @@ class MatrixFederationAgentTests(TestCase):
expected_sni (bytes): SNI that we expect the outgoing connection to send
target_server (bytes): target server that we should redirect to in the
.well-known response.
+ Returns:
+ HTTPChannel: server impl
"""
# make the connection for .well-known
well_known_server = self._make_connection(
@@ -133,9 +145,10 @@ class MatrixFederationAgentTests(TestCase):
# check the .well-known request and send a response
self.assertEqual(len(well_known_server.requests), 1)
request = well_known_server.requests[0]
- self._send_well_known_response(request, target_server)
+ self._send_well_known_response(request, target_server, headers=response_headers)
+ return well_known_server
- def _send_well_known_response(self, request, target_server):
+ def _send_well_known_response(self, request, target_server, headers={}):
"""Check that an incoming request looks like a valid .well-known request, and
send back the response.
"""
@@ -146,6 +159,8 @@ class MatrixFederationAgentTests(TestCase):
[b'testserv'],
)
# send back a response
+ for k, v in headers.items():
+ request.setHeader(k, v)
request.write(b'{ "m.server": "%s" }' % (target_server,))
request.finish()
@@ -448,6 +463,13 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
+ self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
+
+ # check the cache expires
+ self.reactor.pump((25 * 3600,))
+ self.well_known_cache.expire()
+ self.assertNotIn(b"testserv", self.well_known_cache)
+
def test_get_hostname_srv(self):
"""
Test the behaviour when there is a single SRV record
@@ -661,6 +683,126 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,))
self.successResultOf(test_d)
+ @defer.inlineCallbacks
+ def do_get_well_known(self, serv):
+ try:
+ result = yield self.agent._get_well_known(serv)
+ logger.info("Result from well-known fetch: %s", result)
+ except Exception as e:
+ logger.warning("Error fetching well-known: %s", e)
+ raise
+ defer.returnValue(result)
+
+ def test_well_known_cache(self):
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ fetch_d = self.do_get_well_known(b'testserv')
+
+ # there should be an attempt to connect on port 443 for the .well-known
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ well_known_server = self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ response_headers={b'Cache-Control': b'max-age=10'},
+ target_server=b"target-server",
+ )
+
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r, b'target-server')
+
+ # close the tcp connection
+ well_known_server.loseConnection()
+
+ # repeat the request: it should hit the cache
+ fetch_d = self.do_get_well_known(b'testserv')
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r, b'target-server')
+
+ # expire the cache
+ self.reactor.pump((10.0,))
+
+ # now it should connect again
+ fetch_d = self.do_get_well_known(b'testserv')
+
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 443)
+
+ self._handle_well_known_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ target_server=b"other-server",
+ )
+
+ r = self.successResultOf(fetch_d)
+ self.assertEqual(r, b'other-server')
+
+
+class TestCachePeriodFromHeaders(TestCase):
+ def test_cache_control(self):
+ # uppercase
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}),
+ ), 100,
+ )
+
+ # missing value
+ self.assertIsNone(_cache_period_from_headers(
+ Headers({b'Cache-Control': [b'max-age=, bar']}),
+ ))
+
+ # hackernews: bogus due to semicolon
+ self.assertIsNone(_cache_period_from_headers(
+ Headers({b'Cache-Control': [b'private; max-age=0']}),
+ ))
+
+ # github
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}),
+ ), 0,
+ )
+
+ # google
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'cache-control': [b'private, max-age=0']}),
+ ), 0,
+ )
+
+ def test_expires(self):
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
+ time_now=lambda: 1548833700
+ ), 33,
+ )
+
+ # cache-control overrides expires
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({
+ b'cache-control': [b'max-age=10'],
+ b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']
+ }),
+ time_now=lambda: 1548833700
+ ), 10,
+ )
+
+ # invalid expires means immediate expiry
+ self.assertEqual(
+ _cache_period_from_headers(
+ Headers({b'Expires': [b'0']}),
+ ), 0,
+ )
+
def _check_logcontext(context):
current = LoggingContext.current_context()
diff --git a/tests/server.py b/tests/server.py
index 3d7ae9875c..fc1e76d146 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -360,6 +360,7 @@ class FakeTransport(object):
"""
disconnecting = False
+ disconnected = False
buffer = attr.ib(default=b'')
producer = attr.ib(default=None)
@@ -370,14 +371,16 @@ class FakeTransport(object):
return None
def loseConnection(self, reason=None):
- logger.info("FakeTransport: loseConnection(%s)", reason)
if not self.disconnecting:
+ logger.info("FakeTransport: loseConnection(%s)", reason)
self.disconnecting = True
if self._protocol:
self._protocol.connectionLost(reason)
+ self.disconnected = True
def abortConnection(self):
- self.disconnecting = True
+ logger.info("FakeTransport: abortConnection()")
+ self.loseConnection()
def pauseProducing(self):
if not self.producer:
@@ -416,9 +419,16 @@ class FakeTransport(object):
# TLSMemoryBIOProtocol
return
+ if self.disconnected:
+ return
+ logger.info("%s->%s: %s", self._protocol, self.other, self.buffer)
+
if getattr(self.other, "transport") is not None:
- self.other.dataReceived(self.buffer)
- self.buffer = b""
+ try:
+ self.other.dataReceived(self.buffer)
+ self.buffer = b""
+ except Exception as e:
+ logger.warning("Exception writing to protocol: %s", e)
return
self._reactor.callLater(0.0, _write)
diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
new file mode 100644
index 0000000000..03b3c15db6
--- /dev/null
+++ b/tests/util/caches/test_ttlcache.py
@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from synapse.util.caches.ttlcache import TTLCache
+
+from tests import unittest
+
+
+class CacheTestCase(unittest.TestCase):
+ def setUp(self):
+ self.mock_timer = Mock(side_effect=lambda: 100.0)
+ self.cache = TTLCache("test_cache", self.mock_timer)
+
+ def test_get(self):
+ """simple set/get tests"""
+ self.cache.set('one', '1', 10)
+ self.cache.set('two', '2', 20)
+ self.cache.set('three', '3', 30)
+
+ self.assertEqual(len(self.cache), 3)
+
+ self.assertTrue('one' in self.cache)
+ self.assertEqual(self.cache.get('one'), '1')
+ self.assertEqual(self.cache['one'], '1')
+ self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110))
+ self.assertEqual(self.cache._metrics.hits, 3)
+ self.assertEqual(self.cache._metrics.misses, 0)
+
+ self.cache.set('two', '2.5', 20)
+ self.assertEqual(self.cache['two'], '2.5')
+ self.assertEqual(self.cache._metrics.hits, 4)
+
+ # non-existent-item tests
+ self.assertEqual(self.cache.get('four', '4'), '4')
+ self.assertIs(self.cache.get('four', None), None)
+
+ with self.assertRaises(KeyError):
+ self.cache['four']
+
+ with self.assertRaises(KeyError):
+ self.cache.get('four')
+
+ with self.assertRaises(KeyError):
+ self.cache.get_with_expiry('four')
+
+ self.assertEqual(self.cache._metrics.hits, 4)
+ self.assertEqual(self.cache._metrics.misses, 5)
+
+ def test_expiry(self):
+ self.cache.set('one', '1', 10)
+ self.cache.set('two', '2', 20)
+ self.cache.set('three', '3', 30)
+
+ self.assertEqual(len(self.cache), 3)
+ self.assertEqual(self.cache['one'], '1')
+ self.assertEqual(self.cache['two'], '2')
+
+ # enough for the first entry to expire, but not the rest
+ self.mock_timer.side_effect = lambda: 110.0
+
+ self.assertEqual(len(self.cache), 2)
+ self.assertFalse('one' in self.cache)
+ self.assertEqual(self.cache['two'], '2')
+ self.assertEqual(self.cache['three'], '3')
+
+ self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120))
+
+ self.assertEqual(self.cache._metrics.hits, 5)
+ self.assertEqual(self.cache._metrics.misses, 0)
|