summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/config/test_load.py2
-rw-r--r--tests/config/test_room_directory.py4
-rw-r--r--tests/handlers/test_register.py124
-rw-r--r--tests/handlers/test_typing.py8
-rw-r--r--tests/http/test_fedclient.py99
-rw-r--r--tests/replication/tcp/__init__.py14
-rw-r--r--tests/replication/tcp/streams/__init__.py14
-rw-r--r--tests/replication/tcp/streams/_base.py74
-rw-r--r--tests/replication/tcp/streams/test_receipts.py46
-rw-r--r--tests/rest/client/v1/test_admin.py66
-rw-r--r--tests/rest/client/v1/utils.py125
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py92
-rw-r--r--tests/unittest.py12
-rw-r--r--tests/utils.py34
14 files changed, 459 insertions, 255 deletions
diff --git a/tests/config/test_load.py b/tests/config/test_load.py

index d5f1777093..6bfc1970ad 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py
@@ -43,7 +43,7 @@ class ConfigLoadingTestCase(unittest.TestCase): self.generate_config() with open(self.file, "r") as f: - raw = yaml.load(f) + raw = yaml.safe_load(f) self.assertIn("macaroon_secret_key", raw) config = HomeServerConfig.load_config("", ["-c", self.file]) diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
index 3dc2631523..47fffcfeb2 100644 --- a/tests/config/test_room_directory.py +++ b/tests/config/test_room_directory.py
@@ -22,7 +22,7 @@ from tests import unittest class RoomDirectoryConfigTestCase(unittest.TestCase): def test_alias_creation_acl(self): - config = yaml.load(""" + config = yaml.safe_load(""" alias_creation_rules: - user_id: "*bob*" alias: "*" @@ -74,7 +74,7 @@ class RoomDirectoryConfigTestCase(unittest.TestCase): )) def test_room_publish_acl(self): - config = yaml.load(""" + config = yaml.safe_load(""" alias_creation_rules: [] room_list_publication_rules: diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 2217eb2a10..017ea0385e 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -22,8 +22,6 @@ from synapse.api.errors import ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler from synapse.types import RoomAlias, UserID, create_requester -from tests.utils import default_config, setup_test_homeserver - from .. import unittest @@ -32,26 +30,23 @@ class RegistrationHandlers(object): self.registration_handler = RegistrationHandler(hs) -class RegistrationTestCase(unittest.TestCase): +class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ - @defer.inlineCallbacks - def setUp(self): - self.mock_distributor = Mock() - self.mock_distributor.declare("registered_user") - self.mock_captcha_client = Mock() - - hs_config = default_config("test") + def make_homeserver(self, reactor, clock): + hs_config = self.default_config("test") # some of the tests rely on us having a user consent version hs_config.user_consent_version = "test_consent_version" hs_config.max_mau_value = 50 - self.hs = yield setup_test_homeserver( - self.addCleanup, - config=hs_config, - expire_access_token=True, - ) + hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) + return hs + + def prepare(self, reactor, clock, hs): + self.mock_distributor = Mock() + self.mock_distributor.declare("registered_user") + self.mock_captcha_client = Mock() self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret') ) @@ -63,136 +58,133 @@ class RegistrationTestCase(unittest.TestCase): self.requester = create_requester("@requester:test") - @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): frank = UserID.from_string("@frank:test") user_id = frank.to_string() requester = create_requester(user_id) - result_user_id, result_token = yield self.handler.get_or_create_user( - requester, frank.localpart, "Frankie" + result_user_id, result_token = self.get_success( + self.handler.get_or_create_user(requester, frank.localpart, "Frankie") ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) self.assertEquals(result_token, 'secret') - @defer.inlineCallbacks def test_if_user_exists(self): store = self.hs.get_datastore() frank = UserID.from_string("@frank:test") - yield store.register( - user_id=frank.to_string(), - token="jkv;g498752-43gj['eamb!-5", - password_hash=None, + self.get_success( + store.register( + user_id=frank.to_string(), + token="jkv;g498752-43gj['eamb!-5", + password_hash=None, + ) ) local_part = frank.localpart user_id = frank.to_string() requester = create_requester(user_id) - result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, None + result_user_id, result_token = self.get_success( + self.handler.get_or_create_user(requester, local_part, None) ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) - @defer.inlineCallbacks def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.handler.get_or_create_user(self.requester, 'a', "display_name") + self.get_success( + self.handler.get_or_create_user(self.requester, 'a', "display_name") + ) - @defer.inlineCallbacks def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - yield self.handler.get_or_create_user(self.requester, 'c', "User") + self.get_success(self.handler.get_or_create_user(self.requester, 'c', "User")) - @defer.inlineCallbacks def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user(self.requester, 'b', "display_name") + self.get_failure( + self.handler.get_or_create_user(self.requester, 'b', "display_name"), + ResourceLimitError, + ) self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user(self.requester, 'b', "display_name") + self.get_failure( + self.handler.get_or_create_user(self.requester, 'b', "display_name"), + ResourceLimitError, + ) - @defer.inlineCallbacks def test_register_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.register(localpart="local_part") + self.get_failure( + self.handler.register(localpart="local_part"), ResourceLimitError + ) self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.register(localpart="local_part") + self.get_failure( + self.handler.register(localpart="local_part"), ResourceLimitError + ) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = yield self.handler.register(localpart='jeff') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='jeff')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) - room_id = yield directory_handler.get_association(room_alias) + room_id = self.get_success(directory_handler.get_association(room_alias)) self.assertTrue(room_id['room_id'] in rooms) self.assertEqual(len(rooms), 1) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms_with_no_rooms(self): self.hs.config.auto_join_rooms = [] frank = UserID.from_string("@frank:test") - res = yield self.handler.register(frank.localpart) + res = self.get_success(self.handler.register(frank.localpart)) self.assertEqual(res[0], frank.to_string()) - rooms = yield self.store.get_rooms_for_user(res[0]) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_where_room_is_another_domain(self): self.hs.config.auto_join_rooms = ["#room:another"] frank = UserID.from_string("@frank:test") - res = yield self.handler.register(frank.localpart) + res = self.get_success(self.handler.register(frank.localpart)) self.assertEqual(res[0], frank.to_string()) - rooms = yield self.store.get_rooms_for_user(res[0]) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_where_auto_create_is_false(self): self.hs.config.autocreate_auto_join_rooms = False room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = yield self.handler.register(localpart='jeff') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='jeff')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms_when_support_user_exists(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] self.store.is_support_user = Mock(return_value=True) - res = yield self.handler.register(localpart='support') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='support')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) - with self.assertRaises(SynapseError): - yield directory_handler.get_association(room_alias) + self.get_failure(directory_handler.get_association(room_alias), SynapseError) - @defer.inlineCallbacks def test_auto_create_auto_join_where_no_consent(self): """Test to ensure that the first user is not auto-joined to a room if they have not given general consent. @@ -208,27 +200,27 @@ class RegistrationTestCase(unittest.TestCase): # (Messing with the internals of event_creation_handler is fragile # but can't see a better way to do this. One option could be to subclass # the test with custom config.) - event_creation_handler._block_events_without_consent_error = ("Error") + event_creation_handler._block_events_without_consent_error = "Error" event_creation_handler._consent_uri_builder = Mock() room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] # When:- # * the user is registered and post consent actions are called - res = yield self.handler.register(localpart='jeff') - yield self.handler.post_consent_actions(res[0]) + res = self.get_success(self.handler.register(localpart='jeff')) + self.get_success(self.handler.post_consent_actions(res[0])) # Then:- # * Ensure that they have not been joined to the room - rooms = yield self.store.get_rooms_for_user(res[0]) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_register_support_user(self): - res = yield self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) + res = self.get_success( + self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) + ) self.assertTrue(self.store.is_support_user(res[0])) - @defer.inlineCallbacks def test_register_not_support_user(self): - res = yield self.handler.register(localpart='user') + res = self.get_success(self.handler.register(localpart='user')) self.assertFalse(self.store.is_support_user(res[0])) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 13486930fb..6460cbc708 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py
@@ -180,7 +180,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( "farm", - path="/_matrix/federation/v1/send/1000000/", + path="/_matrix/federation/v1/send/1000000", data=_expect_edu_transaction( "m.typing", content={ @@ -192,6 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): json_data_callback=ANY, long_retries=True, backoff_on_404=True, + try_trailing_slash_on_400=True, ) def test_started_typing_remote_recv(self): @@ -201,7 +202,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): (request, channel) = self.make_request( "PUT", - "/_matrix/federation/v1/send/1000000/", + "/_matrix/federation/v1/send/1000000", _make_edu_transaction_json( "m.typing", content={ @@ -257,7 +258,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( "farm", - path="/_matrix/federation/v1/send/1000000/", + path="/_matrix/federation/v1/send/1000000", data=_expect_edu_transaction( "m.typing", content={ @@ -269,6 +270,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): json_data_callback=ANY, long_retries=True, backoff_on_404=True, + try_trailing_slash_on_400=True, ) self.assertEquals(self.event_source.get_current_key(), 1) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index b03b37affe..cd8e086f86 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py
@@ -268,6 +268,105 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, TimeoutError) + def test_client_requires_trailing_slashes(self): + """ + If a connection is made to a client but the client rejects it due to + requiring a trailing slash. We need to retry the request with a + trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. + """ + d = self.cl.get_json( + "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, + ) + + # Send the request + self.pump() + + # there should have been a call to connectTCP + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (_host, _port, factory, _timeout, _bindAddress) = clients[0] + + # complete the connection and wire it up to a fake transport + client = factory.buildProtocol(None) + conn = StringTransport() + client.makeConnection(conn) + + # that should have made it send the request to the connection + self.assertRegex(conn.value(), b"^GET /foo/bar") + + # Clear the original request data before sending a response + conn.clear() + + # Send the HTTP response + client.dataReceived( + b"HTTP/1.1 400 Bad Request\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 59\r\n" + b"\r\n" + b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}' + ) + + # We should get another request with a trailing slash + self.assertRegex(conn.value(), b"^GET /foo/bar/") + + # Send a happy response this time + client.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 2\r\n" + b"\r\n" + b'{}' + ) + + # We should get a successful response + r = self.successResultOf(d) + self.assertEqual(r, {}) + + def test_client_does_not_retry_on_400_plus(self): + """ + Another test for trailing slashes but now test that we don't retry on + trailing slashes on a non-400/M_UNRECOGNIZED response. + + See test_client_requires_trailing_slashes() for context. + """ + d = self.cl.get_json( + "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, + ) + + # Send the request + self.pump() + + # there should have been a call to connectTCP + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (_host, _port, factory, _timeout, _bindAddress) = clients[0] + + # complete the connection and wire it up to a fake transport + client = factory.buildProtocol(None) + conn = StringTransport() + client.makeConnection(conn) + + # that should have made it send the request to the connection + self.assertRegex(conn.value(), b"^GET /foo/bar") + + # Clear the original request data before sending a response + conn.clear() + + # Send the HTTP response + client.dataReceived( + b"HTTP/1.1 404 Not Found\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 2\r\n" + b"\r\n" + b"{}" + ) + + # We should not get another request + self.assertEqual(conn.value(), b"") + + # We should get a 404 failure response + self.failureResultOf(d) + def test_client_sends_body(self): self.cl.post_json( "testserv:8008", "foo/bar", timeout=10000, diff --git a/tests/replication/tcp/__init__.py b/tests/replication/tcp/__init__.py new file mode 100644
index 0000000000..1453d04571 --- /dev/null +++ b/tests/replication/tcp/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/tcp/streams/__init__.py b/tests/replication/tcp/streams/__init__.py new file mode 100644
index 0000000000..1453d04571 --- /dev/null +++ b/tests/replication/tcp/streams/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py new file mode 100644
index 0000000000..38b368a972 --- /dev/null +++ b/tests/replication/tcp/streams/_base.py
@@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.replication.tcp.commands import ReplicateCommand +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from tests import unittest +from tests.server import FakeTransport + + +class BaseStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests of the replication streams""" + def prepare(self, reactor, clock, hs): + # build a replication server + server_factory = ReplicationStreamProtocolFactory(self.hs) + self.streamer = server_factory.streamer + server = server_factory.buildProtocol(None) + + # build a replication client, with a dummy handler + self.test_handler = TestReplicationClientHandler() + self.client = ClientReplicationStreamProtocol( + "client", "test", clock, self.test_handler + ) + + # wire them together + self.client.makeConnection(FakeTransport(server, reactor)) + server.makeConnection(FakeTransport(self.client, reactor)) + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump(0.1) + + def replicate_stream(self, stream, token="NOW"): + """Make the client end a REPLICATE command to set up a subscription to a stream""" + self.client.send_command(ReplicateCommand(stream, token)) + + +class TestReplicationClientHandler(object): + """Drop-in for ReplicationClientHandler which just collects RDATA rows""" + def __init__(self): + self.received_rdata_rows = [] + + def get_streams_to_replicate(self): + return {} + + def get_currently_syncing_users(self): + return [] + + def update_connection(self, connection): + pass + + def finished_connecting(self): + pass + + def on_rdata(self, stream_name, token, rows): + for r in rows: + self.received_rdata_rows.append( + (stream_name, token, r) + ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py new file mode 100644
index 0000000000..9aa9dfe82e --- /dev/null +++ b/tests/replication/tcp/streams/test_receipts.py
@@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.replication.tcp.streams import ReceiptsStreamRow + +from tests.replication.tcp.streams._base import BaseStreamTestCase + +USER_ID = "@feeling:blue" +ROOM_ID = "!room:blue" +EVENT_ID = "$event:blue" + + +class ReceiptsStreamTestCase(BaseStreamTestCase): + def test_receipt(self): + # make the client subscribe to the receipts stream + self.replicate_stream("receipts", "NOW") + + # tell the master to send a new receipt + self.get_success( + self.hs.get_datastore().insert_receipt( + ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} + ) + ) + self.replicate() + + # there should be one RDATA command + rdata_rows = self.test_handler.received_rdata_rows + self.assertEqual(1, len(rdata_rows)) + self.assertEqual(rdata_rows[0][0], "receipts") + row = rdata_rows[0][2] # type: ReceiptsStreamRow + self.assertEqual(ROOM_ID, row.room_id) + self.assertEqual("m.read", row.receipt_type) + self.assertEqual(USER_ID, row.user_id) + self.assertEqual(EVENT_ID, row.event_id) + self.assertEqual({"a": 1}, row.data) diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
index 0caa4aa802..ef38473bd6 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py
@@ -20,7 +20,7 @@ import json from mock import Mock from synapse.api.constants import UserTypes -from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v1 import admin, events, login, room from tests import unittest @@ -359,7 +359,9 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): servlets = [ admin.register_servlets, login.register_servlets, + events.register_servlets, room.register_servlets, + room.register_deprecated_servlets, ] def prepare(self, reactor, clock, hs): @@ -426,3 +428,65 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): self.store.get_users_in_room(room_id), ) self.assertEqual([], users_in_room) + + @unittest.DEBUG + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + """ + + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (room_id,) + request, channel = self.make_request( + "PUT", + url.encode('ascii'), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_token, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode('ascii'), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert we can no longer peek into the room + self._assert_peek(room_id, expect_code=403) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. + """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", + url.encode('ascii'), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"], + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", + url.encode('ascii'), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"], + ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 9c401bf300..05b0143c42 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py
@@ -18,136 +18,11 @@ import time import attr -from twisted.internet import defer - from synapse.api.constants import Membership -from tests import unittest from tests.server import make_request, render -class RestTestCase(unittest.TestCase): - """Contains extra helper functions to quickly and clearly perform a given - REST action, which isn't the focus of the test. - - This subclass assumes there are mock_resource and auth_user_id attributes. - """ - - def __init__(self, *args, **kwargs): - super(RestTestCase, self).__init__(*args, **kwargs) - self.mock_resource = None - self.auth_user_id = None - - @defer.inlineCallbacks - def create_room_as(self, room_creator, is_public=True, tok=None): - temp_id = self.auth_user_id - self.auth_user_id = room_creator - path = "/createRoom" - content = "{}" - if not is_public: - content = '{"visibility":"private"}' - if tok: - path = path + "?access_token=%s" % tok - (code, response) = yield self.mock_resource.trigger("POST", path, content) - self.assertEquals(200, code, msg=str(response)) - self.auth_user_id = temp_id - defer.returnValue(response["room_id"]) - - @defer.inlineCallbacks - def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=src, - targ=targ, - tok=tok, - membership=Membership.INVITE, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def join(self, room=None, user=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.JOIN, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def leave(self, room=None, user=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.LEAVE, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def change_membership(self, room, src, targ, membership, tok=None, expect_code=200): - temp_id = self.auth_user_id - self.auth_user_id = src - - path = "/rooms/%s/state/m.room.member/%s" % (room, targ) - if tok: - path = path + "?access_token=%s" % tok - - data = {"membership": membership} - - (code, response) = yield self.mock_resource.trigger( - "PUT", path, json.dumps(data) - ) - self.assertEquals( - expect_code, - code, - msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response), - ) - - self.auth_user_id = temp_id - - @defer.inlineCallbacks - def register(self, user_id): - (code, response) = yield self.mock_resource.trigger( - "POST", - "/register", - json.dumps( - {"user": user_id, "password": "test", "type": "m.login.password"} - ), - ) - self.assertEquals(200, code, msg=response) - defer.returnValue(response) - - @defer.inlineCallbacks - def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) - if body is None: - body = "body_text_here" - - path = "/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) - content = '{"msgtype":"m.text","body":"%s"}' % body - if tok: - path = path + "?access_token=%s" % tok - - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(expect_code, code, msg=str(response)) - - def assert_dict(self, required, actual): - """Does a partial assert of a dict. - - Args: - required (dict): The keys and value which MUST be in 'actual'. - actual (dict): The test result. Extra keys will not be checked. - """ - for key in required: - self.assertEquals( - required[key], actual[key], msg="%s mismatch. %s" % (key, actual) - ) - - @attr.s class RestHelper(object): """Contains extra helper functions to quickly and clearly perform a given diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 3bd9f1e9c1..be73e718c2 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -1,3 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2018, 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from mock import Mock from twisted.internet import defer @@ -9,16 +24,18 @@ from synapse.server_notices.resource_limits_server_notices import ( ) from tests import unittest -from tests.utils import default_config, setup_test_homeserver -class TestResourceLimitsServerNotices(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs_config = default_config(name="test") +class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): + + def make_homeserver(self, reactor, clock): + hs_config = self.default_config("test") hs_config.server_notices_mxid = "@server:test" - self.hs = yield setup_test_homeserver(self.addCleanup, config=hs_config) + hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) + return hs + + def prepare(self, reactor, clock, hs): self.server_notices_sender = self.hs.get_server_notices_sender() # relying on [1] is far from ideal, but the only case where @@ -53,23 +70,21 @@ class TestResourceLimitsServerNotices(unittest.TestCase): self._rlsn._store.get_tags_for_room = Mock(return_value={}) self.hs.config.admin_contact = "mailto:user@test.com" - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_flag_off(self): """Tests cases where the flags indicate nothing to do""" # test hs disabled case self.hs.config.hs_disabled = True - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() # Test when mau limiting disabled self.hs.config.hs_disabled = False self.hs.limit_usage_by_mau = False - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" @@ -81,13 +96,14 @@ class TestResourceLimitsServerNotices(unittest.TestCase): return_value=defer.succeed({"123": mock_event}) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event self._send_notice.assert_called_once() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): - """Test when user has blocked notice, but notice ought to be there (NOOP)""" + """ + Test when user has blocked notice, but notice ought to be there (NOOP) + """ self._rlsn._auth.check_auth_blocking = Mock( side_effect=ResourceLimitError(403, 'foo') ) @@ -98,52 +114,49 @@ class TestResourceLimitsServerNotices(unittest.TestCase): self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_add_blocked_notice(self): - """Test when user does not have blocked notice, but should have one""" + """ + Test when user does not have blocked notice, but should have one + """ self._rlsn._auth.check_auth_blocking = Mock( side_effect=ResourceLimitError(403, 'foo') ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check contents, but 2 calls == set blocking event self.assertTrue(self._send_notice.call_count == 2) - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): - """Test when user does not have blocked notice, nor should they (NOOP)""" - + """ + Test when user does not have blocked notice, nor should they (NOOP) + """ self._rlsn._auth.check_auth_blocking = Mock() - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): - - """Test when user is not part of the MAU cohort - this should not ever + """ + Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() -class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) +class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = self.hs.get_datastore() self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_manager = self.hs.get_server_notices_manager() @@ -168,26 +181,27 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): self.hs.config.admin_contact = "mailto:user@test.com" - @defer.inlineCallbacks def test_server_notice_only_sent_once(self): self.store.get_monthly_active_count = Mock(return_value=1000) self.store.user_last_seen_monthly_active = Mock(return_value=1000) # Call the function multiple times to ensure we only send the notice once - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Now lets get the last load of messages in the service notice room and # check that there is only one server notice - room_id = yield self.server_notices_manager.get_notice_room_for_user( - self.user_id + room_id = self.get_success( + self.server_notices_manager.get_notice_room_for_user(self.user_id) ) - token = yield self.event_source.get_current_token() - events, _ = yield self.store.get_recent_events_for_room( - room_id, limit=100, end_token=token.room_key + token = self.get_success(self.event_source.get_current_token()) + events, _ = self.get_success( + self.store.get_recent_events_for_room( + room_id, limit=100, end_token=token.room_key + ) ) count = 0 diff --git a/tests/unittest.py b/tests/unittest.py
index 7772a47078..27403de908 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -314,6 +314,9 @@ class HomeserverTestCase(TestCase): """ kwargs = dict(kwargs) kwargs.update(self._hs_args) + if "config" not in kwargs: + config = self.default_config() + kwargs["config"] = config hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() @@ -336,6 +339,15 @@ class HomeserverTestCase(TestCase): self.pump(by=by) return self.successResultOf(d) + def get_failure(self, d, exc): + """ + Run a Deferred and get a Failure from it. The failure must be of the type `exc`. + """ + if not isinstance(d, Deferred): + return d + self.pump() + return self.failureResultOf(d, exc) + def register_user(self, username, password, admin=False): """ Register a user. Requires the Admin API be registered. diff --git a/tests/utils.py b/tests/utils.py
index 67e99a0e40..1b8eeb5167 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,6 +44,10 @@ from synapse.util.logcontext import LoggingContext from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. +# +# When running under postgres, we first create a base database with the name +# POSTGRES_BASE_DB and update it to the current schema. Then, for each test case, we +# create another unique database, using the base database as a template. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False) POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None) @@ -50,28 +55,20 @@ POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None) POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None) POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) +# the dbname we will connect to in order to create the base database. +POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres" -def setupdb(): +def setupdb(): # If we're using PostgreSQL, set up the db once if USE_POSTGRES_FOR_TESTS: - pgconfig = { - "name": "psycopg2", - "args": { - "database": POSTGRES_BASE_DB, - "user": POSTGRES_USER, - "host": POSTGRES_HOST, - "password": POSTGRES_PASSWORD, - "cp_min": 1, - "cp_max": 5, - }, - } - config = Mock() - config.password_providers = [] - config.database_config = pgconfig - db_engine = create_engine(pgconfig) + # create a PostgresEngine + db_engine = create_engine({"name": "psycopg2", "args": {}}) + + # connect to postgres to create the base database. db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD + user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True cur = db_conn.cursor() @@ -96,7 +93,8 @@ def setupdb(): def _cleanup(): db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD + user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True cur = db_conn.cursor()