diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py
index 0c23068bcf..b5ad99348d 100644
--- a/tests/config/test_generate.py
+++ b/tests/config/test_generate.py
@@ -51,7 +51,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
"lemurs.win.log.config",
"lemurs.win.signing.key",
"lemurs.win.tls.crt",
- "lemurs.win.tls.dh",
"lemurs.win.tls.key",
]
),
diff --git a/tests/http/federation/__init__.py b/tests/http/federation/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/tests/http/federation/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
new file mode 100644
index 0000000000..7a3881f558
--- /dev/null
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -0,0 +1,240 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+import treq
+
+from twisted.internet import defer
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.test.ssl_helpers import ServerTLSContext
+from twisted.web.http import HTTPChannel
+
+from synapse.crypto.context_factory import ClientTLSOptionsFactory
+from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.util.logcontext import LoggingContext
+
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.unittest import TestCase
+
+logger = logging.getLogger(__name__)
+
+
+class MatrixFederationAgentTests(TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ self.mock_resolver = Mock()
+
+ self.agent = MatrixFederationAgent(
+ reactor=self.reactor,
+ tls_client_options_factory=ClientTLSOptionsFactory(None),
+ _srv_resolver=self.mock_resolver,
+ )
+
+ def _make_connection(self, client_factory, expected_sni):
+ """Builds a test server, and completes the outgoing client connection
+
+ Returns:
+ HTTPChannel: the test server
+ """
+
+ # build the test server
+ server_tls_protocol = _build_test_server()
+
+ # now, tell the client protocol factory to build the client protocol (it will be a
+ # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
+ # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
+ # a FakeTransport.
+ #
+ # Normally this would be done by the TCP socket code in Twisted, but we are
+ # stubbing that out here.
+ client_protocol = client_factory.buildProtocol(None)
+ client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor))
+
+ # tell the server tls protocol to send its stuff back to the client, too
+ server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
+
+ # give the reactor a pump to get the TLS juices flowing.
+ self.reactor.pump((0.1,))
+
+ # check the SNI
+ server_name = server_tls_protocol._tlsConnection.get_servername()
+ self.assertEqual(
+ server_name,
+ expected_sni,
+ "Expected SNI %s but got %s" % (expected_sni, server_name),
+ )
+
+ # fish the test server back out of the server-side TLS protocol.
+ return server_tls_protocol.wrappedProtocol
+
+ @defer.inlineCallbacks
+ def _make_get_request(self, uri):
+ """
+ Sends a simple GET request via the agent, and checks its logcontext management
+ """
+ with LoggingContext("one") as context:
+ fetch_d = self.agent.request(b'GET', uri)
+
+ # Nothing happened yet
+ self.assertNoResult(fetch_d)
+
+ # should have reset logcontext to the sentinel
+ _check_logcontext(LoggingContext.sentinel)
+
+ try:
+ fetch_res = yield fetch_d
+ defer.returnValue(fetch_res)
+ finally:
+ _check_logcontext(context)
+
+ def test_get(self):
+ """
+ happy-path test of a GET request
+ """
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+ test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=b"testserv",
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ self.assertEqual(
+ request.requestHeaders.getRawHeaders(b'host'),
+ [b'testserv:8448']
+ )
+ content = request.content.read()
+ self.assertEqual(content, b'')
+
+ # Deferred is still without a result
+ self.assertNoResult(test_d)
+
+ # send the headers
+ request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json'])
+ request.write('')
+
+ self.reactor.pump((0.1,))
+
+ response = self.successResultOf(test_d)
+
+ # that should give us a Response object
+ self.assertEqual(response.code, 200)
+
+ # Send the body
+ request.write('{ "a": 1 }'.encode('ascii'))
+ request.finish()
+
+ self.reactor.pump((0.1,))
+
+ # check it can be read
+ json = self.successResultOf(treq.json_content(response))
+ self.assertEqual(json, {"a": 1})
+
+ def test_get_ip_address(self):
+ """
+ Test the behaviour when the server name contains an explicit IP (with no port)
+ """
+
+ # the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
+ self.mock_resolver.resolve_service.side_effect = lambda _: []
+
+ # then there will be a getaddrinfo on the IP
+ self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
+
+ test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ self.mock_resolver.resolve_service.assert_called_once()
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8448)
+
+ # make a test server, and wire up the client
+ http_server = self._make_connection(
+ client_factory,
+ expected_sni=None,
+ )
+
+ self.assertEqual(len(http_server.requests), 1)
+ request = http_server.requests[0]
+ self.assertEqual(request.method, b'GET')
+ self.assertEqual(request.path, b'/foo/bar')
+ # XXX currently broken
+ # self.assertEqual(
+ # request.requestHeaders.getRawHeaders(b'host'),
+ # [b'1.2.3.4:8448']
+ # )
+
+ # finish the request
+ request.finish()
+ self.reactor.pump((0.1,))
+ self.successResultOf(test_d)
+
+
+def _check_logcontext(context):
+ current = LoggingContext.current_context()
+ if current is not context:
+ raise AssertionError(
+ "Expected logcontext %s but was %s" % (context, current),
+ )
+
+
+def _build_test_server():
+ """Construct a test server
+
+ This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
+
+ Returns:
+ TLSMemoryBIOProtocol
+ """
+ server_factory = Factory.forProtocol(HTTPChannel)
+ # Request.finish expects the factory to have a 'log' method.
+ server_factory.log = _log_request
+
+ server_tls_factory = TLSMemoryBIOFactory(
+ ServerTLSContext(), isClient=False, wrappedFactory=server_factory,
+ )
+
+ return server_tls_factory.buildProtocol(None)
+
+
+def _log_request(request):
+ """Implements Factory.log, which is expected by Request.finish"""
+ logger.info("Completed request %s", request)
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
new file mode 100644
index 0000000000..a872e2441e
--- /dev/null
+++ b/tests/http/federation/test_srv_resolver.py
@@ -0,0 +1,207 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.internet.error import ConnectError
+from twisted.names import dns, error
+
+from synapse.http.federation.srv_resolver import SrvResolver
+from synapse.util.logcontext import LoggingContext
+
+from tests import unittest
+from tests.utils import MockClock
+
+
+class SrvResolverTestCase(unittest.TestCase):
+ def test_resolve(self):
+ dns_client_mock = Mock()
+
+ service_name = b"test_service.example.com"
+ host_name = b"example.com"
+
+ answer_srv = dns.RRHeader(
+ type=dns.SRV, payload=dns.Record_SRV(target=host_name)
+ )
+
+ result_deferred = Deferred()
+ dns_client_mock.lookupService.return_value = result_deferred
+
+ cache = {}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
+
+ @defer.inlineCallbacks
+ def do_lookup():
+
+ with LoggingContext("one") as ctx:
+ resolve_d = resolver.resolve_service(service_name)
+
+ self.assertNoResult(resolve_d)
+
+ # should have reset to the sentinel context
+ self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+
+ result = yield resolve_d
+
+ # should have restored our context
+ self.assertIs(LoggingContext.current_context(), ctx)
+
+ defer.returnValue(result)
+
+ test_d = do_lookup()
+ self.assertNoResult(test_d)
+
+ dns_client_mock.lookupService.assert_called_once_with(service_name)
+
+ result_deferred.callback(
+ ([answer_srv], None, None)
+ )
+
+ servers = self.successResultOf(test_d)
+
+ self.assertEquals(len(servers), 1)
+ self.assertEquals(servers, cache[service_name])
+ self.assertEquals(servers[0].host, host_name)
+
+ @defer.inlineCallbacks
+ def test_from_cache_expired_and_dns_fail(self):
+ dns_client_mock = Mock()
+ dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
+
+ service_name = b"test_service.example.com"
+
+ entry = Mock(spec_set=["expires"])
+ entry.expires = 0
+
+ cache = {service_name: [entry]}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
+
+ servers = yield resolver.resolve_service(service_name)
+
+ dns_client_mock.lookupService.assert_called_once_with(service_name)
+
+ self.assertEquals(len(servers), 1)
+ self.assertEquals(servers, cache[service_name])
+
+ @defer.inlineCallbacks
+ def test_from_cache(self):
+ clock = MockClock()
+
+ dns_client_mock = Mock(spec_set=['lookupService'])
+ dns_client_mock.lookupService = Mock(spec_set=[])
+
+ service_name = b"test_service.example.com"
+
+ entry = Mock(spec_set=["expires"])
+ entry.expires = 999999999
+
+ cache = {service_name: [entry]}
+ resolver = SrvResolver(
+ dns_client=dns_client_mock, cache=cache, get_time=clock.time,
+ )
+
+ servers = yield resolver.resolve_service(service_name)
+
+ self.assertFalse(dns_client_mock.lookupService.called)
+
+ self.assertEquals(len(servers), 1)
+ self.assertEquals(servers, cache[service_name])
+
+ @defer.inlineCallbacks
+ def test_empty_cache(self):
+ dns_client_mock = Mock()
+
+ dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
+
+ service_name = b"test_service.example.com"
+
+ cache = {}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
+
+ with self.assertRaises(error.DNSServerError):
+ yield resolver.resolve_service(service_name)
+
+ @defer.inlineCallbacks
+ def test_name_error(self):
+ dns_client_mock = Mock()
+
+ dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
+
+ service_name = b"test_service.example.com"
+
+ cache = {}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
+
+ servers = yield resolver.resolve_service(service_name)
+
+ self.assertEquals(len(servers), 0)
+ self.assertEquals(len(cache), 0)
+
+ def test_disabled_service(self):
+ """
+ test the behaviour when there is a single record which is ".".
+ """
+ service_name = b"test_service.example.com"
+
+ lookup_deferred = Deferred()
+ dns_client_mock = Mock()
+ dns_client_mock.lookupService.return_value = lookup_deferred
+ cache = {}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
+
+ resolve_d = resolver.resolve_service(service_name)
+ self.assertNoResult(resolve_d)
+
+ # returning a single "." should make the lookup fail with a ConenctError
+ lookup_deferred.callback((
+ [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))],
+ None,
+ None,
+ ))
+
+ self.failureResultOf(resolve_d, ConnectError)
+
+ def test_non_srv_answer(self):
+ """
+ test the behaviour when the dns server gives us a spurious non-SRV response
+ """
+ service_name = b"test_service.example.com"
+
+ lookup_deferred = Deferred()
+ dns_client_mock = Mock()
+ dns_client_mock.lookupService.return_value = lookup_deferred
+ cache = {}
+ resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
+
+ resolve_d = resolver.resolve_service(service_name)
+ self.assertNoResult(resolve_d)
+
+ lookup_deferred.callback((
+ [
+ dns.RRHeader(type=dns.A, payload=dns.Record_A()),
+ dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")),
+ ],
+ None,
+ None,
+ ))
+
+ servers = self.successResultOf(resolve_d)
+
+ self.assertEquals(len(servers), 1)
+ self.assertEquals(servers, cache[service_name])
+ self.assertEquals(servers[0].host, b"host")
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index b2e38276d8..d37f8f9981 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -15,8 +15,10 @@
from mock import Mock
+from twisted.internet import defer
from twisted.internet.defer import TimeoutError
from twisted.internet.error import ConnectingCancelledError, DNSLookupError
+from twisted.test.proto_helpers import StringTransport
from twisted.web.client import ResponseNeverReceived
from twisted.web.http import HTTPChannel
@@ -25,11 +27,20 @@ from synapse.http.matrixfederationclient import (
MatrixFederationHttpClient,
MatrixFederationRequest,
)
+from synapse.util.logcontext import LoggingContext
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase
+def check_logcontext(context):
+ current = LoggingContext.current_context()
+ if current is not context:
+ raise AssertionError(
+ "Expected logcontext %s but was %s" % (context, current),
+ )
+
+
class FederationClientTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
@@ -42,9 +53,73 @@ class FederationClientTests(HomeserverTestCase):
self.cl = MatrixFederationHttpClient(self.hs)
self.reactor.lookups["testserv"] = "1.2.3.4"
+ def test_client_get(self):
+ """
+ happy-path test of a GET request
+ """
+ @defer.inlineCallbacks
+ def do_request():
+ with LoggingContext("one") as context:
+ fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
+
+ # Nothing happened yet
+ self.assertNoResult(fetch_d)
+
+ # should have reset logcontext to the sentinel
+ check_logcontext(LoggingContext.sentinel)
+
+ try:
+ fetch_res = yield fetch_d
+ defer.returnValue(fetch_res)
+ finally:
+ check_logcontext(context)
+
+ test_d = do_request()
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(test_d)
+
+ # Make sure treq is trying to connect
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8008)
+
+ # complete the connection and wire it up to a fake transport
+ protocol = factory.buildProtocol(None)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ # that should have made it send the request to the transport
+ self.assertRegex(transport.value(), b"^GET /foo/bar")
+
+ # Deferred is still without a result
+ self.assertNoResult(test_d)
+
+ # Send it the HTTP response
+ res_json = '{ "a": 1 }'.encode('ascii')
+ protocol.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Server: Fake\r\n"
+ b"Content-Type: application/json\r\n"
+ b"Content-Length: %i\r\n"
+ b"\r\n"
+ b"%s" % (len(res_json), res_json)
+ )
+
+ self.pump()
+
+ res = self.successResultOf(test_d)
+
+ # check the response is as expected
+ self.assertEqual(res, {"a": 1})
+
def test_dns_error(self):
"""
- If the DNS raising returns an error, it will bubble up.
+ If the DNS lookup returns an error, it will bubble up.
"""
d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
self.pump()
@@ -53,6 +128,28 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed)
self.assertIsInstance(f.value.inner_exception, DNSLookupError)
+ def test_client_connection_refused(self):
+ d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+
+ self.pump()
+
+ # Nothing happened yet
+ self.assertNoResult(d)
+
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(port, 8008)
+ e = Exception("go away")
+ factory.clientConnectionFailed(None, e)
+ self.pump(0.5)
+
+ f = self.failureResultOf(d)
+
+ self.assertIsInstance(f.value, RequestSendFailed)
+ self.assertIs(f.value.inner_exception, e)
+
def test_client_never_connect(self):
"""
If the HTTP request is not connected and is timed out, it'll give a
@@ -63,7 +160,7 @@ class FederationClientTests(HomeserverTestCase):
self.pump()
# Nothing happened yet
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
@@ -72,7 +169,7 @@ class FederationClientTests(HomeserverTestCase):
self.assertEqual(clients[0][1], 8008)
# Deferred is still without a result
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Push by enough to time it out
self.reactor.advance(10.5)
@@ -94,7 +191,7 @@ class FederationClientTests(HomeserverTestCase):
self.pump()
# Nothing happened yet
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
@@ -107,7 +204,7 @@ class FederationClientTests(HomeserverTestCase):
client.makeConnection(conn)
# Deferred is still without a result
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Push by enough to time it out
self.reactor.advance(10.5)
@@ -135,7 +232,7 @@ class FederationClientTests(HomeserverTestCase):
client.makeConnection(conn)
# Deferred does not have a result
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Send it the HTTP response
client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")
@@ -159,7 +256,7 @@ class FederationClientTests(HomeserverTestCase):
client.makeConnection(conn)
# Deferred does not have a result
- self.assertFalse(d.called)
+ self.assertNoResult(d)
# Send it the HTTP response
client.dataReceived(
@@ -195,3 +292,42 @@ class FederationClientTests(HomeserverTestCase):
request = server.requests[0]
content = request.content.read()
self.assertEqual(content, b'{"a":"b"}')
+
+ def test_closes_connection(self):
+ """Check that the client closes unused HTTP connections"""
+ d = self.cl.get_json("testserv:8008", "foo/bar")
+
+ self.pump()
+
+ # there should have been a call to connectTCP
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (_host, _port, factory, _timeout, _bindAddress) = clients[0]
+
+ # complete the connection and wire it up to a fake transport
+ client = factory.buildProtocol(None)
+ conn = StringTransport()
+ client.makeConnection(conn)
+
+ # that should have made it send the request to the connection
+ self.assertRegex(conn.value(), b"^GET /foo/bar")
+
+ # Send the HTTP response
+ client.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Content-Type: application/json\r\n"
+ b"Content-Length: 2\r\n"
+ b"\r\n"
+ b"{}"
+ )
+
+ # We should get a successful response
+ r = self.successResultOf(d)
+ self.assertEqual(r, {})
+
+ self.assertFalse(conn.disconnecting)
+
+ # wait for a while
+ self.pump(120)
+
+ self.assertTrue(conn.disconnecting)
diff --git a/tests/server.py b/tests/server.py
index db43fa0db8..ed2a046ae6 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,4 +1,5 @@
import json
+import logging
from io import BytesIO
from six import text_type
@@ -22,6 +23,8 @@ from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth
+logger = logging.getLogger(__name__)
+
class TimedOutException(Exception):
"""
@@ -339,7 +342,7 @@ def get_clock():
return (clock, hs_clock)
-@attr.s
+@attr.s(cmp=False)
class FakeTransport(object):
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -414,6 +417,11 @@ class FakeTransport(object):
self.buffer = self.buffer + byt
def _write():
+ if not self.buffer:
+ # nothing to do. Don't write empty buffers: it upsets the
+ # TLSMemoryBIOProtocol
+ return
+
if getattr(self.other, "transport") is not None:
self.other.dataReceived(self.buffer)
self.buffer = b""
@@ -421,7 +429,10 @@ class FakeTransport(object):
self._reactor.callLater(0.0, _write)
- _write()
+ # always actually do the write asynchronously. Some protocols (notably the
+ # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
+ # still doing a write. Doing a callLater here breaks the cycle.
+ self._reactor.callLater(0.0, _write)
def writeSequence(self, seq):
for x in seq:
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 829f47d2e8..452d76ddd5 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.db_pool.runWithConnection = runWithConnection
config = Mock()
+ config._enable_native_upserts = False
config.event_cache_size = 1
config.database_config = {"name": "sqlite3"}
hs = TestHomeServer(
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 4577e9422b..858efe4992 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -62,6 +62,77 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)
+ def test_insert_new_client_ip_none_device_id(self):
+ """
+ An insert with a device ID of NULL will not create a new entry, but
+ update an existing entry in the user_ips table.
+ """
+ self.reactor.advance(12345678)
+
+ user_id = "@user:id"
+
+ # Add & trigger the storage loop
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", None
+ )
+ )
+ self.reactor.advance(200)
+ self.pump(0)
+
+ result = self.get_success(
+ self.store._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+ desc="get_user_ip_and_agents",
+ )
+ )
+
+ self.assertEqual(
+ result,
+ [
+ {
+ 'access_token': 'access_token',
+ 'ip': 'ip',
+ 'user_agent': 'user_agent',
+ 'device_id': None,
+ 'last_seen': 12345678000,
+ }
+ ],
+ )
+
+ # Add another & trigger the storage loop
+ self.get_success(
+ self.store.insert_client_ip(
+ user_id, "access_token", "ip", "user_agent", None
+ )
+ )
+ self.reactor.advance(10)
+ self.pump(0)
+
+ result = self.get_success(
+ self.store._simple_select_list(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+ desc="get_user_ip_and_agents",
+ )
+ )
+ # Only one result, has been upserted.
+ self.assertEqual(
+ result,
+ [
+ {
+ 'access_token': 'access_token',
+ 'ip': 'ip',
+ 'user_agent': 'user_agent',
+ 'device_id': None,
+ 'last_seen': 12345878000,
+ }
+ ],
+ )
+
def test_disabled_monthly_active_user(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
diff --git a/tests/test_dns.py b/tests/test_dns.py
deleted file mode 100644
index 90bd34be34..0000000000
--- a/tests/test_dns.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from mock import Mock
-
-from twisted.internet import defer
-from twisted.names import dns, error
-
-from synapse.http.endpoint import resolve_service
-
-from tests.utils import MockClock
-
-from . import unittest
-
-
-@unittest.DEBUG
-class DnsTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def test_resolve(self):
- dns_client_mock = Mock()
-
- service_name = "test_service.example.com"
- host_name = "example.com"
-
- answer_srv = dns.RRHeader(
- type=dns.SRV, payload=dns.Record_SRV(target=host_name)
- )
-
- dns_client_mock.lookupService.return_value = defer.succeed(
- ([answer_srv], None, None)
- )
-
- cache = {}
-
- servers = yield resolve_service(
- service_name, dns_client=dns_client_mock, cache=cache
- )
-
- dns_client_mock.lookupService.assert_called_once_with(service_name)
-
- self.assertEquals(len(servers), 1)
- self.assertEquals(servers, cache[service_name])
- self.assertEquals(servers[0].host, host_name)
-
- @defer.inlineCallbacks
- def test_from_cache_expired_and_dns_fail(self):
- dns_client_mock = Mock()
- dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
-
- service_name = "test_service.example.com"
-
- entry = Mock(spec_set=["expires"])
- entry.expires = 0
-
- cache = {service_name: [entry]}
-
- servers = yield resolve_service(
- service_name, dns_client=dns_client_mock, cache=cache
- )
-
- dns_client_mock.lookupService.assert_called_once_with(service_name)
-
- self.assertEquals(len(servers), 1)
- self.assertEquals(servers, cache[service_name])
-
- @defer.inlineCallbacks
- def test_from_cache(self):
- clock = MockClock()
-
- dns_client_mock = Mock(spec_set=['lookupService'])
- dns_client_mock.lookupService = Mock(spec_set=[])
-
- service_name = "test_service.example.com"
-
- entry = Mock(spec_set=["expires"])
- entry.expires = 999999999
-
- cache = {service_name: [entry]}
-
- servers = yield resolve_service(
- service_name, dns_client=dns_client_mock, cache=cache, clock=clock
- )
-
- self.assertFalse(dns_client_mock.lookupService.called)
-
- self.assertEquals(len(servers), 1)
- self.assertEquals(servers, cache[service_name])
-
- @defer.inlineCallbacks
- def test_empty_cache(self):
- dns_client_mock = Mock()
-
- dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
-
- service_name = "test_service.example.com"
-
- cache = {}
-
- with self.assertRaises(error.DNSServerError):
- yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache)
-
- @defer.inlineCallbacks
- def test_name_error(self):
- dns_client_mock = Mock()
-
- dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
-
- service_name = "test_service.example.com"
-
- cache = {}
-
- servers = yield resolve_service(
- service_name, dns_client=dns_client_mock, cache=cache
- )
-
- self.assertEquals(len(servers), 0)
- self.assertEquals(len(cache), 0)
diff --git a/tests/test_server.py b/tests/test_server.py
index 634a8fbca5..08fb3fe02f 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -19,7 +19,7 @@ from six import StringIO
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -30,12 +30,18 @@ from synapse.util import Clock
from synapse.util.logcontext import make_deferred_yieldable
from tests import unittest
-from tests.server import FakeTransport, make_request, render, setup_test_homeserver
+from tests.server import (
+ FakeTransport,
+ ThreadedMemoryReactorClock,
+ make_request,
+ render,
+ setup_test_homeserver,
+)
class JsonResourceTests(unittest.TestCase):
def setUp(self):
- self.reactor = MemoryReactorClock()
+ self.reactor = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.reactor)
self.homeserver = setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
diff --git a/tests/unittest.py b/tests/unittest.py
index 78d2f740f9..cda549c783 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -96,7 +96,7 @@ class TestCase(unittest.TestCase):
method = getattr(self, methodName)
- level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
+ level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING))
@around(self)
def setUp(orig):
@@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase):
"""
kwargs = dict(kwargs)
kwargs.update(self._hs_args)
- return setup_test_homeserver(self.addCleanup, *args, **kwargs)
+ hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
+ stor = hs.get_datastore()
+
+ # Run the database background updates.
+ if hasattr(stor, "do_next_background_update"):
+ while not self.get_success(stor.has_completed_background_updates()):
+ self.get_success(stor.do_next_background_update(1))
+
+ return hs
def pump(self, by=0.0):
"""
diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py
new file mode 100644
index 0000000000..84dd71e47a
--- /dev/null
+++ b/tests/util/test_async_utils.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.task import Clock
+
+from synapse.util import logcontext
+from synapse.util.async_helpers import timeout_deferred
+from synapse.util.logcontext import LoggingContext
+
+from tests.unittest import TestCase
+
+
+class TimeoutDeferredTest(TestCase):
+ def setUp(self):
+ self.clock = Clock()
+
+ def test_times_out(self):
+ """Basic test case that checks that the original deferred is cancelled and that
+ the timing-out deferred is errbacked
+ """
+ cancelled = [False]
+
+ def canceller(_d):
+ cancelled[0] = True
+
+ non_completing_d = Deferred(canceller)
+ timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
+
+ self.assertNoResult(timing_out_d)
+ self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
+
+ self.clock.pump((1.0, ))
+
+ self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
+ self.failureResultOf(timing_out_d, defer.TimeoutError, )
+
+ def test_times_out_when_canceller_throws(self):
+ """Test that we have successfully worked around
+ https://twistedmatrix.com/trac/ticket/9534"""
+
+ def canceller(_d):
+ raise Exception("can't cancel this deferred")
+
+ non_completing_d = Deferred(canceller)
+ timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
+
+ self.assertNoResult(timing_out_d)
+
+ self.clock.pump((1.0, ))
+
+ self.failureResultOf(timing_out_d, defer.TimeoutError, )
+
+ def test_logcontext_is_preserved_on_cancellation(self):
+ blocking_was_cancelled = [False]
+
+ @defer.inlineCallbacks
+ def blocking():
+ non_completing_d = Deferred()
+ with logcontext.PreserveLoggingContext():
+ try:
+ yield non_completing_d
+ except CancelledError:
+ blocking_was_cancelled[0] = True
+ raise
+
+ with logcontext.LoggingContext("one") as context_one:
+ # the errbacks should be run in the test logcontext
+ def errback(res, deferred_name):
+ self.assertIs(
+ LoggingContext.current_context(), context_one,
+ "errback %s run in unexpected logcontext %s" % (
+ deferred_name, LoggingContext.current_context(),
+ )
+ )
+ return res
+
+ original_deferred = blocking()
+ original_deferred.addErrback(errback, "orig")
+ timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
+ self.assertNoResult(timing_out_d)
+ self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel)
+ timing_out_d.addErrback(errback, "timingout")
+
+ self.clock.pump((1.0, ))
+
+ self.assertTrue(
+ blocking_was_cancelled[0],
+ "non-completing deferred was not cancelled",
+ )
+ self.failureResultOf(timing_out_d, defer.TimeoutError, )
+ self.assertIs(LoggingContext.current_context(), context_one)
|