summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2019-01-24 14:51:35 +0000
committerAndrew Morgan <andrew@amorgan.xyz>2019-01-24 14:51:35 +0000
commitacaca1b4e97b062ab0f926794435b901b7a3fa4e (patch)
tree79c9381f3e82d0c4ded96485988adb6d9016a910 /tests
parentFix missing synapse metrics import (diff)
parentisort (diff)
downloadsynapse-acaca1b4e97b062ab0f926794435b901b7a3fa4e.tar.xz
Merge branch 'anoa/room_dir_quick_fix' into matrix-org-hotfixes
Diffstat (limited to 'tests')
-rw-r--r--tests/config/test_generate.py1
-rw-r--r--tests/http/federation/__init__.py14
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py240
-rw-r--r--tests/http/federation/test_srv_resolver.py207
-rw-r--r--tests/http/test_fedclient.py150
-rw-r--r--tests/server.py15
-rw-r--r--tests/storage/test_base.py1
-rw-r--r--tests/storage/test_client_ips.py71
-rw-r--r--tests/test_dns.py129
-rw-r--r--tests/test_server.py12
-rw-r--r--tests/unittest.py12
-rw-r--r--tests/util/test_async_utils.py104
12 files changed, 812 insertions, 144 deletions
diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py

index 0c23068bcf..b5ad99348d 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py
@@ -51,7 +51,6 @@ class ConfigGenerationTestCase(unittest.TestCase): "lemurs.win.log.config", "lemurs.win.signing.key", "lemurs.win.tls.crt", - "lemurs.win.tls.dh", "lemurs.win.tls.key", ] ), diff --git a/tests/http/federation/__init__.py b/tests/http/federation/__init__.py new file mode 100644
index 0000000000..1453d04571 --- /dev/null +++ b/tests/http/federation/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py new file mode 100644
index 0000000000..7a3881f558 --- /dev/null +++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -0,0 +1,240 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from mock import Mock + +import treq + +from twisted.internet import defer +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.test.ssl_helpers import ServerTLSContext +from twisted.web.http import HTTPChannel + +from synapse.crypto.context_factory import ClientTLSOptionsFactory +from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent +from synapse.util.logcontext import LoggingContext + +from tests.server import FakeTransport, ThreadedMemoryReactorClock +from tests.unittest import TestCase + +logger = logging.getLogger(__name__) + + +class MatrixFederationAgentTests(TestCase): + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + + self.mock_resolver = Mock() + + self.agent = MatrixFederationAgent( + reactor=self.reactor, + tls_client_options_factory=ClientTLSOptionsFactory(None), + _srv_resolver=self.mock_resolver, + ) + + def _make_connection(self, client_factory, expected_sni): + """Builds a test server, and completes the outgoing client connection + + Returns: + HTTPChannel: the test server + """ + + # build the test server + server_tls_protocol = _build_test_server() + + # now, tell the client protocol factory to build the client protocol (it will be a + # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an + # HTTP11ClientProtocol) and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection(FakeTransport(server_tls_protocol, self.reactor)) + + # tell the server tls protocol to send its stuff back to the client, too + server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor)) + + # give the reactor a pump to get the TLS juices flowing. + self.reactor.pump((0.1,)) + + # check the SNI + server_name = server_tls_protocol._tlsConnection.get_servername() + self.assertEqual( + server_name, + expected_sni, + "Expected SNI %s but got %s" % (expected_sni, server_name), + ) + + # fish the test server back out of the server-side TLS protocol. + return server_tls_protocol.wrappedProtocol + + @defer.inlineCallbacks + def _make_get_request(self, uri): + """ + Sends a simple GET request via the agent, and checks its logcontext management + """ + with LoggingContext("one") as context: + fetch_d = self.agent.request(b'GET', uri) + + # Nothing happened yet + self.assertNoResult(fetch_d) + + # should have reset logcontext to the sentinel + _check_logcontext(LoggingContext.sentinel) + + try: + fetch_res = yield fetch_d + defer.returnValue(fetch_res) + finally: + _check_logcontext(context) + + def test_get(self): + """ + happy-path test of a GET request + """ + self.reactor.lookups["testserv"] = "1.2.3.4" + test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 8448) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, + expected_sni=b"testserv", + ) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b'GET') + self.assertEqual(request.path, b'/foo/bar') + self.assertEqual( + request.requestHeaders.getRawHeaders(b'host'), + [b'testserv:8448'] + ) + content = request.content.read() + self.assertEqual(content, b'') + + # Deferred is still without a result + self.assertNoResult(test_d) + + # send the headers + request.responseHeaders.setRawHeaders(b'Content-Type', [b'application/json']) + request.write('') + + self.reactor.pump((0.1,)) + + response = self.successResultOf(test_d) + + # that should give us a Response object + self.assertEqual(response.code, 200) + + # Send the body + request.write('{ "a": 1 }'.encode('ascii')) + request.finish() + + self.reactor.pump((0.1,)) + + # check it can be read + json = self.successResultOf(treq.json_content(response)) + self.assertEqual(json, {"a": 1}) + + def test_get_ip_address(self): + """ + Test the behaviour when the server name contains an explicit IP (with no port) + """ + + # the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?) + self.mock_resolver.resolve_service.side_effect = lambda _: [] + + # then there will be a getaddrinfo on the IP + self.reactor.lookups["1.2.3.4"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_called_once() + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 8448) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, + expected_sni=None, + ) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b'GET') + self.assertEqual(request.path, b'/foo/bar') + # XXX currently broken + # self.assertEqual( + # request.requestHeaders.getRawHeaders(b'host'), + # [b'1.2.3.4:8448'] + # ) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + + +def _check_logcontext(context): + current = LoggingContext.current_context() + if current is not context: + raise AssertionError( + "Expected logcontext %s but was %s" % (context, current), + ) + + +def _build_test_server(): + """Construct a test server + + This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol + + Returns: + TLSMemoryBIOProtocol + """ + server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + server_tls_factory = TLSMemoryBIOFactory( + ServerTLSContext(), isClient=False, wrappedFactory=server_factory, + ) + + return server_tls_factory.buildProtocol(None) + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py new file mode 100644
index 0000000000..a872e2441e --- /dev/null +++ b/tests/http/federation/test_srv_resolver.py
@@ -0,0 +1,207 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet import defer +from twisted.internet.defer import Deferred +from twisted.internet.error import ConnectError +from twisted.names import dns, error + +from synapse.http.federation.srv_resolver import SrvResolver +from synapse.util.logcontext import LoggingContext + +from tests import unittest +from tests.utils import MockClock + + +class SrvResolverTestCase(unittest.TestCase): + def test_resolve(self): + dns_client_mock = Mock() + + service_name = b"test_service.example.com" + host_name = b"example.com" + + answer_srv = dns.RRHeader( + type=dns.SRV, payload=dns.Record_SRV(target=host_name) + ) + + result_deferred = Deferred() + dns_client_mock.lookupService.return_value = result_deferred + + cache = {} + resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) + + @defer.inlineCallbacks + def do_lookup(): + + with LoggingContext("one") as ctx: + resolve_d = resolver.resolve_service(service_name) + + self.assertNoResult(resolve_d) + + # should have reset to the sentinel context + self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + + result = yield resolve_d + + # should have restored our context + self.assertIs(LoggingContext.current_context(), ctx) + + defer.returnValue(result) + + test_d = do_lookup() + self.assertNoResult(test_d) + + dns_client_mock.lookupService.assert_called_once_with(service_name) + + result_deferred.callback( + ([answer_srv], None, None) + ) + + servers = self.successResultOf(test_d) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + self.assertEquals(servers[0].host, host_name) + + @defer.inlineCallbacks + def test_from_cache_expired_and_dns_fail(self): + dns_client_mock = Mock() + dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) + + service_name = b"test_service.example.com" + + entry = Mock(spec_set=["expires"]) + entry.expires = 0 + + cache = {service_name: [entry]} + resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) + + servers = yield resolver.resolve_service(service_name) + + dns_client_mock.lookupService.assert_called_once_with(service_name) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + + @defer.inlineCallbacks + def test_from_cache(self): + clock = MockClock() + + dns_client_mock = Mock(spec_set=['lookupService']) + dns_client_mock.lookupService = Mock(spec_set=[]) + + service_name = b"test_service.example.com" + + entry = Mock(spec_set=["expires"]) + entry.expires = 999999999 + + cache = {service_name: [entry]} + resolver = SrvResolver( + dns_client=dns_client_mock, cache=cache, get_time=clock.time, + ) + + servers = yield resolver.resolve_service(service_name) + + self.assertFalse(dns_client_mock.lookupService.called) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + + @defer.inlineCallbacks + def test_empty_cache(self): + dns_client_mock = Mock() + + dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) + + service_name = b"test_service.example.com" + + cache = {} + resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) + + with self.assertRaises(error.DNSServerError): + yield resolver.resolve_service(service_name) + + @defer.inlineCallbacks + def test_name_error(self): + dns_client_mock = Mock() + + dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) + + service_name = b"test_service.example.com" + + cache = {} + resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) + + servers = yield resolver.resolve_service(service_name) + + self.assertEquals(len(servers), 0) + self.assertEquals(len(cache), 0) + + def test_disabled_service(self): + """ + test the behaviour when there is a single record which is ".". + """ + service_name = b"test_service.example.com" + + lookup_deferred = Deferred() + dns_client_mock = Mock() + dns_client_mock.lookupService.return_value = lookup_deferred + cache = {} + resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) + + resolve_d = resolver.resolve_service(service_name) + self.assertNoResult(resolve_d) + + # returning a single "." should make the lookup fail with a ConenctError + lookup_deferred.callback(( + [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))], + None, + None, + )) + + self.failureResultOf(resolve_d, ConnectError) + + def test_non_srv_answer(self): + """ + test the behaviour when the dns server gives us a spurious non-SRV response + """ + service_name = b"test_service.example.com" + + lookup_deferred = Deferred() + dns_client_mock = Mock() + dns_client_mock.lookupService.return_value = lookup_deferred + cache = {} + resolver = SrvResolver(dns_client=dns_client_mock, cache=cache) + + resolve_d = resolver.resolve_service(service_name) + self.assertNoResult(resolve_d) + + lookup_deferred.callback(( + [ + dns.RRHeader(type=dns.A, payload=dns.Record_A()), + dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")), + ], + None, + None, + )) + + servers = self.successResultOf(resolve_d) + + self.assertEquals(len(servers), 1) + self.assertEquals(servers, cache[service_name]) + self.assertEquals(servers[0].host, b"host") diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index b2e38276d8..d37f8f9981 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py
@@ -15,8 +15,10 @@ from mock import Mock +from twisted.internet import defer from twisted.internet.defer import TimeoutError from twisted.internet.error import ConnectingCancelledError, DNSLookupError +from twisted.test.proto_helpers import StringTransport from twisted.web.client import ResponseNeverReceived from twisted.web.http import HTTPChannel @@ -25,11 +27,20 @@ from synapse.http.matrixfederationclient import ( MatrixFederationHttpClient, MatrixFederationRequest, ) +from synapse.util.logcontext import LoggingContext from tests.server import FakeTransport from tests.unittest import HomeserverTestCase +def check_logcontext(context): + current = LoggingContext.current_context() + if current is not context: + raise AssertionError( + "Expected logcontext %s but was %s" % (context, current), + ) + + class FederationClientTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): @@ -42,9 +53,73 @@ class FederationClientTests(HomeserverTestCase): self.cl = MatrixFederationHttpClient(self.hs) self.reactor.lookups["testserv"] = "1.2.3.4" + def test_client_get(self): + """ + happy-path test of a GET request + """ + @defer.inlineCallbacks + def do_request(): + with LoggingContext("one") as context: + fetch_d = self.cl.get_json("testserv:8008", "foo/bar") + + # Nothing happened yet + self.assertNoResult(fetch_d) + + # should have reset logcontext to the sentinel + check_logcontext(LoggingContext.sentinel) + + try: + fetch_res = yield fetch_d + defer.returnValue(fetch_res) + finally: + check_logcontext(context) + + test_d = do_request() + + self.pump() + + # Nothing happened yet + self.assertNoResult(test_d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 8008) + + # complete the connection and wire it up to a fake transport + protocol = factory.buildProtocol(None) + transport = StringTransport() + protocol.makeConnection(transport) + + # that should have made it send the request to the transport + self.assertRegex(transport.value(), b"^GET /foo/bar") + + # Deferred is still without a result + self.assertNoResult(test_d) + + # Send it the HTTP response + res_json = '{ "a": 1 }'.encode('ascii') + protocol.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Server: Fake\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: %i\r\n" + b"\r\n" + b"%s" % (len(res_json), res_json) + ) + + self.pump() + + res = self.successResultOf(test_d) + + # check the response is as expected + self.assertEqual(res, {"a": 1}) + def test_dns_error(self): """ - If the DNS raising returns an error, it will bubble up. + If the DNS lookup returns an error, it will bubble up. """ d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) self.pump() @@ -53,6 +128,28 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value.inner_exception, DNSLookupError) + def test_client_connection_refused(self): + d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + + self.pump() + + # Nothing happened yet + self.assertNoResult(d) + + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, '1.2.3.4') + self.assertEqual(port, 8008) + e = Exception("go away") + factory.clientConnectionFailed(None, e) + self.pump(0.5) + + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, RequestSendFailed) + self.assertIs(f.value.inner_exception, e) + def test_client_never_connect(self): """ If the HTTP request is not connected and is timed out, it'll give a @@ -63,7 +160,7 @@ class FederationClientTests(HomeserverTestCase): self.pump() # Nothing happened yet - self.assertFalse(d.called) + self.assertNoResult(d) # Make sure treq is trying to connect clients = self.reactor.tcpClients @@ -72,7 +169,7 @@ class FederationClientTests(HomeserverTestCase): self.assertEqual(clients[0][1], 8008) # Deferred is still without a result - self.assertFalse(d.called) + self.assertNoResult(d) # Push by enough to time it out self.reactor.advance(10.5) @@ -94,7 +191,7 @@ class FederationClientTests(HomeserverTestCase): self.pump() # Nothing happened yet - self.assertFalse(d.called) + self.assertNoResult(d) # Make sure treq is trying to connect clients = self.reactor.tcpClients @@ -107,7 +204,7 @@ class FederationClientTests(HomeserverTestCase): client.makeConnection(conn) # Deferred is still without a result - self.assertFalse(d.called) + self.assertNoResult(d) # Push by enough to time it out self.reactor.advance(10.5) @@ -135,7 +232,7 @@ class FederationClientTests(HomeserverTestCase): client.makeConnection(conn) # Deferred does not have a result - self.assertFalse(d.called) + self.assertNoResult(d) # Send it the HTTP response client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n") @@ -159,7 +256,7 @@ class FederationClientTests(HomeserverTestCase): client.makeConnection(conn) # Deferred does not have a result - self.assertFalse(d.called) + self.assertNoResult(d) # Send it the HTTP response client.dataReceived( @@ -195,3 +292,42 @@ class FederationClientTests(HomeserverTestCase): request = server.requests[0] content = request.content.read() self.assertEqual(content, b'{"a":"b"}') + + def test_closes_connection(self): + """Check that the client closes unused HTTP connections""" + d = self.cl.get_json("testserv:8008", "foo/bar") + + self.pump() + + # there should have been a call to connectTCP + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (_host, _port, factory, _timeout, _bindAddress) = clients[0] + + # complete the connection and wire it up to a fake transport + client = factory.buildProtocol(None) + conn = StringTransport() + client.makeConnection(conn) + + # that should have made it send the request to the connection + self.assertRegex(conn.value(), b"^GET /foo/bar") + + # Send the HTTP response + client.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 2\r\n" + b"\r\n" + b"{}" + ) + + # We should get a successful response + r = self.successResultOf(d) + self.assertEqual(r, {}) + + self.assertFalse(conn.disconnecting) + + # wait for a while + self.pump(120) + + self.assertTrue(conn.disconnecting) diff --git a/tests/server.py b/tests/server.py
index db43fa0db8..ed2a046ae6 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -1,4 +1,5 @@ import json +import logging from io import BytesIO from six import text_type @@ -22,6 +23,8 @@ from synapse.util import Clock from tests.utils import setup_test_homeserver as _sth +logger = logging.getLogger(__name__) + class TimedOutException(Exception): """ @@ -339,7 +342,7 @@ def get_clock(): return (clock, hs_clock) -@attr.s +@attr.s(cmp=False) class FakeTransport(object): """ A twisted.internet.interfaces.ITransport implementation which sends all its data @@ -414,6 +417,11 @@ class FakeTransport(object): self.buffer = self.buffer + byt def _write(): + if not self.buffer: + # nothing to do. Don't write empty buffers: it upsets the + # TLSMemoryBIOProtocol + return + if getattr(self.other, "transport") is not None: self.other.dataReceived(self.buffer) self.buffer = b"" @@ -421,7 +429,10 @@ class FakeTransport(object): self._reactor.callLater(0.0, _write) - _write() + # always actually do the write asynchronously. Some protocols (notably the + # TLSMemoryBIOProtocol) get very confused if a read comes back while they are + # still doing a write. Doing a callLater here breaks the cycle. + self._reactor.callLater(0.0, _write) def writeSequence(self, seq): for x in seq: diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 829f47d2e8..452d76ddd5 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py
@@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.db_pool.runWithConnection = runWithConnection config = Mock() + config._enable_native_upserts = False config.event_cache_size = 1 config.database_config = {"name": "sqlite3"} hs = TestHomeServer( diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 4577e9422b..858efe4992 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -62,6 +62,77 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): r, ) + def test_insert_new_client_ip_none_device_id(self): + """ + An insert with a device ID of NULL will not create a new entry, but + update an existing entry in the user_ips table. + """ + self.reactor.advance(12345678) + + user_id = "@user:id" + + # Add & trigger the storage loop + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", None + ) + ) + self.reactor.advance(200) + self.pump(0) + + result = self.get_success( + self.store._simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + + self.assertEqual( + result, + [ + { + 'access_token': 'access_token', + 'ip': 'ip', + 'user_agent': 'user_agent', + 'device_id': None, + 'last_seen': 12345678000, + } + ], + ) + + # Add another & trigger the storage loop + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", None + ) + ) + self.reactor.advance(10) + self.pump(0) + + result = self.get_success( + self.store._simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + # Only one result, has been upserted. + self.assertEqual( + result, + [ + { + 'access_token': 'access_token', + 'ip': 'ip', + 'user_agent': 'user_agent', + 'device_id': None, + 'last_seen': 12345878000, + } + ], + ) + def test_disabled_monthly_active_user(self): self.hs.config.limit_usage_by_mau = False self.hs.config.max_mau_value = 50 diff --git a/tests/test_dns.py b/tests/test_dns.py deleted file mode 100644
index 90bd34be34..0000000000 --- a/tests/test_dns.py +++ /dev/null
@@ -1,129 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from mock import Mock - -from twisted.internet import defer -from twisted.names import dns, error - -from synapse.http.endpoint import resolve_service - -from tests.utils import MockClock - -from . import unittest - - -@unittest.DEBUG -class DnsTestCase(unittest.TestCase): - @defer.inlineCallbacks - def test_resolve(self): - dns_client_mock = Mock() - - service_name = "test_service.example.com" - host_name = "example.com" - - answer_srv = dns.RRHeader( - type=dns.SRV, payload=dns.Record_SRV(target=host_name) - ) - - dns_client_mock.lookupService.return_value = defer.succeed( - ([answer_srv], None, None) - ) - - cache = {} - - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) - - dns_client_mock.lookupService.assert_called_once_with(service_name) - - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - self.assertEquals(servers[0].host, host_name) - - @defer.inlineCallbacks - def test_from_cache_expired_and_dns_fail(self): - dns_client_mock = Mock() - dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) - - service_name = "test_service.example.com" - - entry = Mock(spec_set=["expires"]) - entry.expires = 0 - - cache = {service_name: [entry]} - - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) - - dns_client_mock.lookupService.assert_called_once_with(service_name) - - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - - @defer.inlineCallbacks - def test_from_cache(self): - clock = MockClock() - - dns_client_mock = Mock(spec_set=['lookupService']) - dns_client_mock.lookupService = Mock(spec_set=[]) - - service_name = "test_service.example.com" - - entry = Mock(spec_set=["expires"]) - entry.expires = 999999999 - - cache = {service_name: [entry]} - - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache, clock=clock - ) - - self.assertFalse(dns_client_mock.lookupService.called) - - self.assertEquals(len(servers), 1) - self.assertEquals(servers, cache[service_name]) - - @defer.inlineCallbacks - def test_empty_cache(self): - dns_client_mock = Mock() - - dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) - - service_name = "test_service.example.com" - - cache = {} - - with self.assertRaises(error.DNSServerError): - yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache) - - @defer.inlineCallbacks - def test_name_error(self): - dns_client_mock = Mock() - - dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) - - service_name = "test_service.example.com" - - cache = {} - - servers = yield resolve_service( - service_name, dns_client=dns_client_mock, cache=cache - ) - - self.assertEquals(len(servers), 0) - self.assertEquals(len(cache), 0) diff --git a/tests/test_server.py b/tests/test_server.py
index 634a8fbca5..08fb3fe02f 100644 --- a/tests/test_server.py +++ b/tests/test_server.py
@@ -19,7 +19,7 @@ from six import StringIO from twisted.internet.defer import Deferred from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET @@ -30,12 +30,18 @@ from synapse.util import Clock from synapse.util.logcontext import make_deferred_yieldable from tests import unittest -from tests.server import FakeTransport, make_request, render, setup_test_homeserver +from tests.server import ( + FakeTransport, + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) class JsonResourceTests(unittest.TestCase): def setUp(self): - self.reactor = MemoryReactorClock() + self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor diff --git a/tests/unittest.py b/tests/unittest.py
index 78d2f740f9..cda549c783 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -96,7 +96,7 @@ class TestCase(unittest.TestCase): method = getattr(self, methodName) - level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR)) + level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING)) @around(self) def setUp(orig): @@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase): """ kwargs = dict(kwargs) kwargs.update(self._hs_args) - return setup_test_homeserver(self.addCleanup, *args, **kwargs) + hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) + stor = hs.get_datastore() + + # Run the database background updates. + if hasattr(stor, "do_next_background_update"): + while not self.get_success(stor.has_completed_background_updates()): + self.get_success(stor.do_next_background_update(1)) + + return hs def pump(self, by=0.0): """ diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py new file mode 100644
index 0000000000..84dd71e47a --- /dev/null +++ b/tests/util/test_async_utils.py
@@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer +from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.task import Clock + +from synapse.util import logcontext +from synapse.util.async_helpers import timeout_deferred +from synapse.util.logcontext import LoggingContext + +from tests.unittest import TestCase + + +class TimeoutDeferredTest(TestCase): + def setUp(self): + self.clock = Clock() + + def test_times_out(self): + """Basic test case that checks that the original deferred is cancelled and that + the timing-out deferred is errbacked + """ + cancelled = [False] + + def canceller(_d): + cancelled[0] = True + + non_completing_d = Deferred(canceller) + timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) + + self.assertNoResult(timing_out_d) + self.assertFalse(cancelled[0], "deferred was cancelled prematurely") + + self.clock.pump((1.0, )) + + self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") + self.failureResultOf(timing_out_d, defer.TimeoutError, ) + + def test_times_out_when_canceller_throws(self): + """Test that we have successfully worked around + https://twistedmatrix.com/trac/ticket/9534""" + + def canceller(_d): + raise Exception("can't cancel this deferred") + + non_completing_d = Deferred(canceller) + timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) + + self.assertNoResult(timing_out_d) + + self.clock.pump((1.0, )) + + self.failureResultOf(timing_out_d, defer.TimeoutError, ) + + def test_logcontext_is_preserved_on_cancellation(self): + blocking_was_cancelled = [False] + + @defer.inlineCallbacks + def blocking(): + non_completing_d = Deferred() + with logcontext.PreserveLoggingContext(): + try: + yield non_completing_d + except CancelledError: + blocking_was_cancelled[0] = True + raise + + with logcontext.LoggingContext("one") as context_one: + # the errbacks should be run in the test logcontext + def errback(res, deferred_name): + self.assertIs( + LoggingContext.current_context(), context_one, + "errback %s run in unexpected logcontext %s" % ( + deferred_name, LoggingContext.current_context(), + ) + ) + return res + + original_deferred = blocking() + original_deferred.addErrback(errback, "orig") + timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock) + self.assertNoResult(timing_out_d) + self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + timing_out_d.addErrback(errback, "timingout") + + self.clock.pump((1.0, )) + + self.assertTrue( + blocking_was_cancelled[0], + "non-completing deferred was not cancelled", + ) + self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.assertIs(LoggingContext.current_context(), context_one)