summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py17
-rw-r--r--tests/api/test_ratelimiting.py20
-rw-r--r--tests/config/test_load.py2
-rw-r--r--tests/config/test_room_directory.py4
-rw-r--r--tests/federation/test_federation_sender.py128
-rw-r--r--tests/handlers/test_directory.py59
-rw-r--r--tests/handlers/test_profile.py4
-rw-r--r--tests/handlers/test_register.py143
-rw-r--r--tests/handlers/test_typing.py294
-rw-r--r--tests/handlers/test_user_directory.py362
-rw-r--r--tests/http/test_fedclient.py99
-rw-r--r--tests/push/test_email.py2
-rw-r--r--tests/replication/slave/storage/_base.py4
-rw-r--r--tests/replication/tcp/__init__.py14
-rw-r--r--tests/replication/tcp/streams/__init__.py14
-rw-r--r--tests/replication/tcp/streams/_base.py74
-rw-r--r--tests/replication/tcp/streams/test_receipts.py46
-rw-r--r--tests/rest/client/v1/test_admin.py175
-rw-r--r--tests/rest/client/v1/test_events.py4
-rw-r--r--tests/rest/client/v1/test_login.py163
-rw-r--r--tests/rest/client/v1/test_rooms.py6
-rw-r--r--tests/rest/client/v1/test_typing.py4
-rw-r--r--tests/rest/client/v1/utils.py125
-rw-r--r--tests/rest/client/v2_alpha/test_register.py53
-rw-r--r--tests/rest/media/v1/test_base.py45
-rw-r--r--tests/server.py30
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py93
-rw-r--r--tests/storage/test_user_directory.py17
-rw-r--r--tests/test_mau.py3
-rw-r--r--tests/unittest.py24
-rw-r--r--tests/utils.py150
31 files changed, 1643 insertions, 535 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py

index d77f20e876..d0d36f96fa 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py
@@ -345,6 +345,23 @@ class AuthTestCase(unittest.TestCase): self.assertEquals(e.exception.code, 403) @defer.inlineCallbacks + def test_hs_disabled_no_server_notices_user(self): + """Check that 'hs_disabled_message' works correctly when there is no + server_notices user. + """ + # this should be the default, but we had a bug where the test was doing the wrong + # thing, so let's make it explicit + self.hs.config.server_notices_mxid = None + + self.hs.config.hs_disabled = True + self.hs.config.hs_disabled_message = "Reason for being disabled" + with self.assertRaises(ResourceLimitError) as e: + yield self.auth.check_auth_blocking() + self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.assertEquals(e.exception.code, 403) + + @defer.inlineCallbacks def test_server_notices_mxid_special_cased(self): self.hs.config.hs_disabled = True user = "@user:server" diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 8933fe3b72..30a255d441 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py
@@ -6,34 +6,34 @@ from tests import unittest class TestRatelimiter(unittest.TestCase): def test_allowed(self): limiter = Ratelimiter() - allowed, time_allowed = limiter.send_message( - user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) self.assertEquals(10., time_allowed) - allowed, time_allowed = limiter.send_message( - user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1 ) self.assertFalse(allowed) self.assertEquals(10., time_allowed) - allowed, time_allowed = limiter.send_message( - user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) self.assertEquals(20., time_allowed) def test_pruning(self): limiter = Ratelimiter() - allowed, time_allowed = limiter.send_message( - user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1 ) self.assertIn("test_id_1", limiter.message_counts) - allowed, time_allowed = limiter.send_message( - user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1 ) self.assertNotIn("test_id_1", limiter.message_counts) diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index d5f1777093..6bfc1970ad 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py
@@ -43,7 +43,7 @@ class ConfigLoadingTestCase(unittest.TestCase): self.generate_config() with open(self.file, "r") as f: - raw = yaml.load(f) + raw = yaml.safe_load(f) self.assertIn("macaroon_secret_key", raw) config = HomeServerConfig.load_config("", ["-c", self.file]) diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
index 3dc2631523..47fffcfeb2 100644 --- a/tests/config/test_room_directory.py +++ b/tests/config/test_room_directory.py
@@ -22,7 +22,7 @@ from tests import unittest class RoomDirectoryConfigTestCase(unittest.TestCase): def test_alias_creation_acl(self): - config = yaml.load(""" + config = yaml.safe_load(""" alias_creation_rules: - user_id: "*bob*" alias: "*" @@ -74,7 +74,7 @@ class RoomDirectoryConfigTestCase(unittest.TestCase): )) def test_room_publish_acl(self): - config = yaml.load(""" + config = yaml.safe_load(""" alias_creation_rules: [] room_list_publication_rules: diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py new file mode 100644
index 0000000000..28e7e27416 --- /dev/null +++ b/tests/federation/test_federation_sender.py
@@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet import defer + +from synapse.types import ReadReceipt + +from tests.unittest import HomeserverTestCase + + +class FederationSenderTestCases(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + return super(FederationSenderTestCases, self).setup_test_homeserver( + state_handler=Mock(spec=["get_current_hosts_in_room"]), + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def test_send_receipts(self): + mock_state_handler = self.hs.get_state_handler() + mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] + + mock_send_transaction = self.hs.get_federation_transport_client().send_transaction + mock_send_transaction.return_value = defer.succeed({}) + + sender = self.hs.get_federation_sender() + receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}) + self.successResultOf(sender.send_read_receipt(receipt)) + + self.pump() + + # expect a call to send_transaction + mock_send_transaction.assert_called_once() + json_cb = mock_send_transaction.call_args[0][1] + data = json_cb() + self.assertEqual(data['edus'], [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['event_id'], + 'data': {'ts': 1234}, + }, + }, + }, + }, + }, + ]) + + def test_send_receipts_with_backoff(self): + """Send two receipts in quick succession; the second should be flushed, but + only after 20ms""" + mock_state_handler = self.hs.get_state_handler() + mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] + + mock_send_transaction = self.hs.get_federation_transport_client().send_transaction + mock_send_transaction.return_value = defer.succeed({}) + + sender = self.hs.get_federation_sender() + receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}) + self.successResultOf(sender.send_read_receipt(receipt)) + + self.pump() + + # expect a call to send_transaction + mock_send_transaction.assert_called_once() + json_cb = mock_send_transaction.call_args[0][1] + data = json_cb() + self.assertEqual(data['edus'], [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['event_id'], + 'data': {'ts': 1234}, + }, + }, + }, + }, + }, + ]) + mock_send_transaction.reset_mock() + + # send the second RR + receipt = ReadReceipt("room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}) + self.successResultOf(sender.send_read_receipt(receipt)) + self.pump() + mock_send_transaction.assert_not_called() + + self.reactor.advance(19) + mock_send_transaction.assert_not_called() + + self.reactor.advance(10) + mock_send_transaction.assert_called_once() + json_cb = mock_send_transaction.call_args[0][1] + data = json_cb() + self.assertEqual(data['edus'], [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['other_id'], + 'data': {'ts': 1234}, + }, + }, + }, + }, + }, + ]) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 9bf395e923..5b2105bc76 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py
@@ -111,7 +111,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def prepare(self, hs, reactor, clock): + def prepare(self, reactor, clock, hs): # We cheekily override the config to add custom alias creation rules config = {} config["alias_creation_rules"] = [ @@ -151,3 +151,60 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): ) self.render(request) self.assertEquals(200, channel.code, channel.result) + + +class TestRoomListSearchDisabled(unittest.HomeserverTestCase): + user_id = "@test:test" + + servlets = [directory.register_servlets, room.register_servlets] + + def prepare(self, reactor, clock, hs): + room_id = self.helper.create_room_as(self.user_id) + + request, channel = self.make_request( + "PUT", + b"directory/list/room/%s" % (room_id.encode('ascii'),), + b'{}', + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + self.room_list_handler = hs.get_room_list_handler() + self.directory_handler = hs.get_handlers().directory_handler + + return hs + + def test_disabling_room_list(self): + self.room_list_handler.enable_room_list_search = True + self.directory_handler.enable_room_list_search = True + + # Room list is enabled so we should get some results + request, channel = self.make_request( + "GET", + b"publicRooms", + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertTrue(len(channel.json_body["chunk"]) > 0) + + self.room_list_handler.enable_room_list_search = False + self.directory_handler.enable_room_list_search = False + + # Room list disabled so we should get no results + request, channel = self.make_request( + "GET", + b"publicRooms", + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertTrue(len(channel.json_body["chunk"]) == 0) + + # Room list disabled so we shouldn't be allowed to publish rooms + room_id = self.helper.create_room_as(self.user_id) + request, channel = self.make_request( + "PUT", + b"directory/list/room/%s" % (room_id.encode('ascii'),), + b'{}', + ) + self.render(request) + self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 80da1c8954..d60c124eec 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py
@@ -55,11 +55,11 @@ class ProfileTestCase(unittest.TestCase): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) self.store = hs.get_datastore() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index c9c1506273..017ea0385e 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -22,8 +22,6 @@ from synapse.api.errors import ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler from synapse.types import RoomAlias, UserID, create_requester -from tests.utils import setup_test_homeserver - from .. import unittest @@ -32,176 +30,197 @@ class RegistrationHandlers(object): self.registration_handler = RegistrationHandler(hs) -class RegistrationTestCase(unittest.TestCase): +class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ - @defer.inlineCallbacks - def setUp(self): + def make_homeserver(self, reactor, clock): + hs_config = self.default_config("test") + + # some of the tests rely on us having a user consent version + hs_config.user_consent_version = "test_consent_version" + hs_config.max_mau_value = 50 + + hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) + return hs + + def prepare(self, reactor, clock, hs): self.mock_distributor = Mock() self.mock_distributor.declare("registered_user") self.mock_captcha_client = Mock() - self.hs = yield setup_test_homeserver( - self.addCleanup, - expire_access_token=True, - ) self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret') ) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.handler = self.hs.get_registration_handler() self.store = self.hs.get_datastore() - self.hs.config.max_mau_value = 50 self.lots_of_users = 100 self.small_number_of_users = 1 self.requester = create_requester("@requester:test") - @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): frank = UserID.from_string("@frank:test") user_id = frank.to_string() requester = create_requester(user_id) - result_user_id, result_token = yield self.handler.get_or_create_user( - requester, frank.localpart, "Frankie" + result_user_id, result_token = self.get_success( + self.handler.get_or_create_user(requester, frank.localpart, "Frankie") ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) self.assertEquals(result_token, 'secret') - @defer.inlineCallbacks def test_if_user_exists(self): store = self.hs.get_datastore() frank = UserID.from_string("@frank:test") - yield store.register( - user_id=frank.to_string(), - token="jkv;g498752-43gj['eamb!-5", - password_hash=None, + self.get_success( + store.register( + user_id=frank.to_string(), + token="jkv;g498752-43gj['eamb!-5", + password_hash=None, + ) ) local_part = frank.localpart user_id = frank.to_string() requester = create_requester(user_id) - result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, None + result_user_id, result_token = self.get_success( + self.handler.get_or_create_user(requester, local_part, None) ) self.assertEquals(result_user_id, user_id) self.assertTrue(result_token is not None) - @defer.inlineCallbacks def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.handler.get_or_create_user(self.requester, 'a', "display_name") + self.get_success( + self.handler.get_or_create_user(self.requester, 'a', "display_name") + ) - @defer.inlineCallbacks def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - yield self.handler.get_or_create_user(self.requester, 'c', "User") + self.get_success(self.handler.get_or_create_user(self.requester, 'c', "User")) - @defer.inlineCallbacks def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user(self.requester, 'b', "display_name") + self.get_failure( + self.handler.get_or_create_user(self.requester, 'b', "display_name"), + ResourceLimitError, + ) self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user(self.requester, 'b', "display_name") + self.get_failure( + self.handler.get_or_create_user(self.requester, 'b', "display_name"), + ResourceLimitError, + ) - @defer.inlineCallbacks def test_register_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.lots_of_users) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.register(localpart="local_part") + self.get_failure( + self.handler.register(localpart="local_part"), ResourceLimitError + ) self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - with self.assertRaises(ResourceLimitError): - yield self.handler.register(localpart="local_part") + self.get_failure( + self.handler.register(localpart="local_part"), ResourceLimitError + ) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = yield self.handler.register(localpart='jeff') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='jeff')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) - room_id = yield directory_handler.get_association(room_alias) + room_id = self.get_success(directory_handler.get_association(room_alias)) self.assertTrue(room_id['room_id'] in rooms) self.assertEqual(len(rooms), 1) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms_with_no_rooms(self): self.hs.config.auto_join_rooms = [] frank = UserID.from_string("@frank:test") - res = yield self.handler.register(frank.localpart) + res = self.get_success(self.handler.register(frank.localpart)) self.assertEqual(res[0], frank.to_string()) - rooms = yield self.store.get_rooms_for_user(res[0]) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_where_room_is_another_domain(self): self.hs.config.auto_join_rooms = ["#room:another"] frank = UserID.from_string("@frank:test") - res = yield self.handler.register(frank.localpart) + res = self.get_success(self.handler.register(frank.localpart)) self.assertEqual(res[0], frank.to_string()) - rooms = yield self.store.get_rooms_for_user(res[0]) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_where_auto_create_is_false(self): self.hs.config.autocreate_auto_join_rooms = False room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = yield self.handler.register(localpart='jeff') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='jeff')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_auto_create_auto_join_rooms_when_support_user_exists(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] self.store.is_support_user = Mock(return_value=True) - res = yield self.handler.register(localpart='support') - rooms = yield self.store.get_rooms_for_user(res[0]) + res = self.get_success(self.handler.register(localpart='support')) + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) directory_handler = self.hs.get_handlers().directory_handler room_alias = RoomAlias.from_string(room_alias_str) - with self.assertRaises(SynapseError): - yield directory_handler.get_association(room_alias) + self.get_failure(directory_handler.get_association(room_alias), SynapseError) - @defer.inlineCallbacks def test_auto_create_auto_join_where_no_consent(self): - self.hs.config.user_consent_at_registration = True - self.hs.config.block_events_without_consent_error = "Error" + """Test to ensure that the first user is not auto-joined to a room if + they have not given general consent. + """ + + # Given:- + # * a user must give consent, + # * they have not given that consent + # * The server is configured to auto-join to a room + # (and autocreate if necessary) + + event_creation_handler = self.hs.get_event_creation_handler() + # (Messing with the internals of event_creation_handler is fragile + # but can't see a better way to do this. One option could be to subclass + # the test with custom config.) + event_creation_handler._block_events_without_consent_error = "Error" + event_creation_handler._consent_uri_builder = Mock() room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - res = yield self.handler.register(localpart='jeff') - yield self.handler.post_consent_actions(res[0]) - rooms = yield self.store.get_rooms_for_user(res[0]) + + # When:- + # * the user is registered and post consent actions are called + res = self.get_success(self.handler.register(localpart='jeff')) + self.get_success(self.handler.post_consent_actions(res[0])) + + # Then:- + # * Ensure that they have not been joined to the room + rooms = self.get_success(self.store.get_rooms_for_user(res[0])) self.assertEqual(len(rooms), 0) - @defer.inlineCallbacks def test_register_support_user(self): - res = yield self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) + res = self.get_success( + self.handler.register(localpart='user', user_type=UserTypes.SUPPORT) + ) self.assertTrue(self.store.is_support_user(res[0])) - @defer.inlineCallbacks def test_register_not_support_user(self): - res = yield self.handler.register(localpart='user') + res = self.get_success(self.handler.register(localpart='user')) self.assertFalse(self.store.is_support_user(res[0])) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 36e136cded..6460cbc708 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py
@@ -24,13 +24,17 @@ from synapse.api.errors import AuthError from synapse.types import UserID from tests import unittest +from tests.utils import register_federation_servlets -from ..utils import ( - DeferredMockCallable, - MockClock, - MockHttpResource, - setup_test_homeserver, -) +# Some local users to test with +U_APPLE = UserID.from_string("@apple:test") +U_BANANA = UserID.from_string("@banana:test") + +# Remote user +U_ONION = UserID.from_string("@onion:farm") + +# Test room id +ROOM_ID = "a-room" def _expect_edu_transaction(edu_type, content, origin="test"): @@ -46,30 +50,21 @@ def _make_edu_transaction_json(edu_type, content): return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8') -class TypingNotificationsTestCase(unittest.TestCase): - """Tests typing notifications to rooms.""" - - @defer.inlineCallbacks - def setUp(self): - self.clock = MockClock() +class TypingNotificationsTestCase(unittest.HomeserverTestCase): + servlets = [register_federation_servlets] - self.mock_http_client = Mock(spec=[]) - self.mock_http_client.put_json = DeferredMockCallable() + def make_homeserver(self, reactor, clock): + # we mock out the keyring so as to skip the authentication check on the + # federation API call. + mock_keyring = Mock(spec=["verify_json_for_server"]) + mock_keyring.verify_json_for_server.return_value = defer.succeed(True) - self.mock_federation_resource = MockHttpResource() - - mock_notifier = Mock() - self.on_new_event = mock_notifier.on_new_event + # we mock out the federation client too + mock_federation_client = Mock(spec=["put_json"]) + mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) - self.auth = Mock(spec=[]) - self.state_handler = Mock() - - hs = yield setup_test_homeserver( - self.addCleanup, - "test", - auth=self.auth, - clock=self.clock, - datastore=Mock( + hs = self.setup_test_homeserver( + datastore=(Mock( spec=[ # Bits that Federation needs "prep_send_transaction", @@ -82,16 +77,21 @@ class TypingNotificationsTestCase(unittest.TestCase): "get_user_directory_stream_pos", "get_current_state_deltas", ] - ), - state_handler=self.state_handler, - handlers=Mock(), - notifier=mock_notifier, - resource_for_client=Mock(), - resource_for_federation=self.mock_federation_resource, - http_client=self.mock_http_client, - keyring=Mock(), + )), + notifier=Mock(), + http_client=mock_federation_client, + keyring=mock_keyring, ) + return hs + + def prepare(self, reactor, clock, hs): + # the tests assume that we are starting at unix time 1000 + reactor.pump((1000, )) + + mock_notifier = hs.get_notifier() + self.on_new_event = mock_notifier.on_new_event + self.handler = hs.get_typing_handler() self.event_source = hs.get_event_sources().sources["typing"] @@ -109,13 +109,12 @@ class TypingNotificationsTestCase(unittest.TestCase): self.datastore.get_received_txn_response = get_received_txn_response - self.room_id = "a-room" - self.room_members = [] def check_joined_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") + hs.get_auth().check_joined_room = check_joined_room def get_joined_hosts_for_room(room_id): return set(member.domain for member in self.room_members) @@ -124,8 +123,7 @@ class TypingNotificationsTestCase(unittest.TestCase): def get_current_user_in_room(room_id): return set(str(u) for u in self.room_members) - - self.state_handler.get_current_user_in_room = get_current_user_in_room + hs.get_state_handler().get_current_user_in_room = get_current_user_in_room self.datastore.get_user_directory_stream_pos.return_value = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam @@ -134,230 +132,210 @@ class TypingNotificationsTestCase(unittest.TestCase): self.datastore.get_current_state_deltas.return_value = None - self.auth.check_joined_room = check_joined_room - self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None - # Some local users to test with - self.u_apple = UserID.from_string("@apple:test") - self.u_banana = UserID.from_string("@banana:test") - - # Remote user - self.u_onion = UserID.from_string("@onion:farm") - - @defer.inlineCallbacks def test_started_typing_local(self): - self.room_members = [self.u_apple, self.u_banana] + self.room_members = [U_APPLE, U_BANANA] self.assertEquals(self.event_source.get_current_key(), 0) - yield self.handler.started_typing( - target_user=self.u_apple, - auth_user=self.u_apple, - room_id=self.room_id, + self.successResultOf(self.handler.started_typing( + target_user=U_APPLE, + auth_user=U_APPLE, + room_id=ROOM_ID, timeout=20000, - ) + )) self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[self.room_id])] + [call('typing_key', 1, rooms=[ROOM_ID])] ) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=0 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=0 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, - "content": {"user_ids": [self.u_apple.to_string()]}, + "room_id": ROOM_ID, + "content": {"user_ids": [U_APPLE.to_string()]}, } ], ) - @defer.inlineCallbacks def test_started_typing_remote_send(self): - self.room_members = [self.u_apple, self.u_onion] - - put_json = self.mock_http_client.put_json - put_json.expect_call_and_return( - call( - "farm", - path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu_transaction( - "m.typing", - content={ - "room_id": self.room_id, - "user_id": self.u_apple.to_string(), - "typing": True, - }, - ), - json_data_callback=ANY, - long_retries=True, - backoff_on_404=True, - ), - defer.succeed((200, "OK")), - ) + self.room_members = [U_APPLE, U_ONION] - yield self.handler.started_typing( - target_user=self.u_apple, - auth_user=self.u_apple, - room_id=self.room_id, + self.successResultOf(self.handler.started_typing( + target_user=U_APPLE, + auth_user=U_APPLE, + room_id=ROOM_ID, timeout=20000, - ) + )) - yield put_json.await_calls() + put_json = self.hs.get_http_client().put_json + put_json.assert_called_once_with( + "farm", + path="/_matrix/federation/v1/send/1000000", + data=_expect_edu_transaction( + "m.typing", + content={ + "room_id": ROOM_ID, + "user_id": U_APPLE.to_string(), + "typing": True, + }, + ), + json_data_callback=ANY, + long_retries=True, + backoff_on_404=True, + try_trailing_slash_on_400=True, + ) - @defer.inlineCallbacks def test_started_typing_remote_recv(self): - self.room_members = [self.u_apple, self.u_onion] + self.room_members = [U_APPLE, U_ONION] self.assertEquals(self.event_source.get_current_key(), 0) - (code, response) = yield self.mock_federation_resource.trigger( + (request, channel) = self.make_request( "PUT", - "/_matrix/federation/v1/send/1000000/", + "/_matrix/federation/v1/send/1000000", _make_edu_transaction_json( "m.typing", content={ - "room_id": self.room_id, - "user_id": self.u_onion.to_string(), + "room_id": ROOM_ID, + "user_id": U_ONION.to_string(), "typing": True, }, ), federation_auth_origin=b'farm', ) + self.render(request) + self.assertEqual(channel.code, 200) self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[self.room_id])] + [call('typing_key', 1, rooms=[ROOM_ID])] ) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=0 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=0 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, - "content": {"user_ids": [self.u_onion.to_string()]}, + "room_id": ROOM_ID, + "content": {"user_ids": [U_ONION.to_string()]}, } ], ) - @defer.inlineCallbacks def test_stopped_typing(self): - self.room_members = [self.u_apple, self.u_banana, self.u_onion] - - put_json = self.mock_http_client.put_json - put_json.expect_call_and_return( - call( - "farm", - path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu_transaction( - "m.typing", - content={ - "room_id": self.room_id, - "user_id": self.u_apple.to_string(), - "typing": False, - }, - ), - json_data_callback=ANY, - long_retries=True, - backoff_on_404=True, - ), - defer.succeed((200, "OK")), - ) + self.room_members = [U_APPLE, U_BANANA, U_ONION] # Gut-wrenching from synapse.handlers.typing import RoomMember - member = RoomMember(self.room_id, self.u_apple.to_string()) + member = RoomMember(ROOM_ID, U_APPLE.to_string()) self.handler._member_typing_until[member] = 1002000 - self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()]) + self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()]) self.assertEquals(self.event_source.get_current_key(), 0) - yield self.handler.stopped_typing( - target_user=self.u_apple, auth_user=self.u_apple, room_id=self.room_id - ) + self.successResultOf(self.handler.stopped_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID + )) self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[self.room_id])] + [call('typing_key', 1, rooms=[ROOM_ID])] ) - yield put_json.await_calls() + put_json = self.hs.get_http_client().put_json + put_json.assert_called_once_with( + "farm", + path="/_matrix/federation/v1/send/1000000", + data=_expect_edu_transaction( + "m.typing", + content={ + "room_id": ROOM_ID, + "user_id": U_APPLE.to_string(), + "typing": False, + }, + ), + json_data_callback=ANY, + long_retries=True, + backoff_on_404=True, + try_trailing_slash_on_400=True, + ) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=0 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=0 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, + "room_id": ROOM_ID, "content": {"user_ids": []}, } ], ) - @defer.inlineCallbacks def test_typing_timeout(self): - self.room_members = [self.u_apple, self.u_banana] + self.room_members = [U_APPLE, U_BANANA] self.assertEquals(self.event_source.get_current_key(), 0) - yield self.handler.started_typing( - target_user=self.u_apple, - auth_user=self.u_apple, - room_id=self.room_id, + self.successResultOf(self.handler.started_typing( + target_user=U_APPLE, + auth_user=U_APPLE, + room_id=ROOM_ID, timeout=10000, - ) + )) self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[self.room_id])] + [call('typing_key', 1, rooms=[ROOM_ID])] ) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=0 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=0 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, - "content": {"user_ids": [self.u_apple.to_string()]}, + "room_id": ROOM_ID, + "content": {"user_ids": [U_APPLE.to_string()]}, } ], ) - self.clock.advance_time(16) + self.reactor.pump([16, ]) self.on_new_event.assert_has_calls( - [call('typing_key', 2, rooms=[self.room_id])] + [call('typing_key', 2, rooms=[ROOM_ID])] ) self.assertEquals(self.event_source.get_current_key(), 2) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=1 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=1 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, + "room_id": ROOM_ID, "content": {"user_ids": []}, } ], @@ -365,29 +343,29 @@ class TypingNotificationsTestCase(unittest.TestCase): # SYN-230 - see if we can still set after timeout - yield self.handler.started_typing( - target_user=self.u_apple, - auth_user=self.u_apple, - room_id=self.room_id, + self.successResultOf(self.handler.started_typing( + target_user=U_APPLE, + auth_user=U_APPLE, + room_id=ROOM_ID, timeout=10000, - ) + )) self.on_new_event.assert_has_calls( - [call('typing_key', 3, rooms=[self.room_id])] + [call('typing_key', 3, rooms=[ROOM_ID])] ) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=0 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=0 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, - "content": {"user_ids": [self.u_apple.to_string()]}, + "room_id": ROOM_ID, + "content": {"user_ids": [U_APPLE.to_string()]}, } ], ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 11f2bae698..f1d0aa42b6 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -14,78 +14,358 @@ # limitations under the License. from mock import Mock -from twisted.internet import defer - from synapse.api.constants import UserTypes -from synapse.handlers.user_directory import UserDirectoryHandler +from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v2_alpha import user_directory from synapse.storage.roommember import ProfileInfo from tests import unittest -from tests.utils import setup_test_homeserver -class UserDirectoryHandlers(object): - def __init__(self, hs): - self.user_directory_handler = UserDirectoryHandler(hs) +class UserDirectoryTestCase(unittest.HomeserverTestCase): + """ + Tests the UserDirectoryHandler. + """ + servlets = [ + login.register_servlets, + admin.register_servlets, + room.register_servlets, + ] -class UserDirectoryTestCase(unittest.TestCase): - """ Tests the UserDirectoryHandler. """ + def make_homeserver(self, reactor, clock): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) - self.store = hs.get_datastore() - hs.handlers = UserDirectoryHandlers(hs) + config = self.default_config() + config.update_user_directory = True + return self.setup_test_homeserver(config=config) - self.handler = hs.get_handlers().user_directory_handler + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() - @defer.inlineCallbacks def test_handle_local_profile_change_with_support_user(self): support_user_id = "@support:test" - yield self.store.register( - user_id=support_user_id, - token="123", - password_hash=None, - user_type=UserTypes.SUPPORT + self.get_success( + self.store.register( + user_id=support_user_id, + token="123", + password_hash=None, + user_type=UserTypes.SUPPORT, + ) ) - yield self.handler.handle_local_profile_change(support_user_id, None) - profile = yield self.store.get_user_in_directory(support_user_id) + self.get_success( + self.handler.handle_local_profile_change(support_user_id, None) + ) + profile = self.get_success(self.store.get_user_in_directory(support_user_id)) self.assertTrue(profile is None) display_name = 'display_name' - profile_info = ProfileInfo( - avatar_url='avatar_url', - display_name=display_name, - ) + profile_info = ProfileInfo(avatar_url='avatar_url', display_name=display_name) regular_user_id = '@regular:test' - yield self.handler.handle_local_profile_change(regular_user_id, profile_info) - profile = yield self.store.get_user_in_directory(regular_user_id) + self.get_success( + self.handler.handle_local_profile_change(regular_user_id, profile_info) + ) + profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) self.assertTrue(profile['display_name'] == display_name) - @defer.inlineCallbacks def test_handle_user_deactivated_support_user(self): s_user_id = "@support:test" - self.store.register( - user_id=s_user_id, - token="123", - password_hash=None, - user_type=UserTypes.SUPPORT + self.get_success( + self.store.register( + user_id=s_user_id, + token="123", + password_hash=None, + user_type=UserTypes.SUPPORT, + ) ) self.store.remove_from_user_dir = Mock() self.store.remove_from_user_in_public_room = Mock() - yield self.handler.handle_user_deactivated(s_user_id) + self.get_success(self.handler.handle_user_deactivated(s_user_id)) self.store.remove_from_user_dir.not_called() self.store.remove_from_user_in_public_room.not_called() - @defer.inlineCallbacks def test_handle_user_deactivated_regular_user(self): r_user_id = "@regular:test" - self.store.register(user_id=r_user_id, token="123", password_hash=None) + self.get_success( + self.store.register(user_id=r_user_id, token="123", password_hash=None) + ) self.store.remove_from_user_dir = Mock() - self.store.remove_from_user_in_public_room = Mock() - yield self.handler.handle_user_deactivated(r_user_id) + self.get_success(self.handler.handle_user_deactivated(r_user_id)) self.store.remove_from_user_dir.called_once_with(r_user_id) - self.store.remove_from_user_in_public_room.assert_called_once_with(r_user_id) + + def test_private_room(self): + """ + A user can be searched for only by people that are either in a public + room, or that share a private chat. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + u3 = self.register_user("user3", "pass") + + # We do not add users to the directory until they join a room. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Check we have populated the database correctly. + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + self.assertEqual( + self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + ) + self.assertEqual(public_users, []) + + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + # We get NO search results when searching for user2 by user3. + s = self.get_success(self.handler.search_users(u3, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + # We get NO search results when searching for user3 by user1. + s = self.get_success(self.handler.search_users(u1, "user3", 10)) + self.assertEqual(len(s["results"]), 0) + + # User 2 then leaves. + self.helper.leave(room, user=u2, tok=u2_token) + + # Check we have removed the values. + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + self.assertEqual(self._compress_shared(shares_private), set()) + self.assertEqual(public_users, []) + + # User1 now gets no search results for any of the other users. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + s = self.get_success(self.handler.search_users(u1, "user3", 10)) + self.assertEqual(len(s["results"]), 0) + + def _compress_shared(self, shared): + """ + Compress a list of users who share rooms dicts to a list of tuples. + """ + r = set() + for i in shared: + r.add((i["user_id"], i["other_user_id"], i["room_id"])) + return r + + def get_users_in_public_rooms(self): + r = self.get_success( + self.store._simple_select_list( + "users_in_public_rooms", None, ("user_id", "room_id") + ) + ) + retval = [] + for i in r: + retval.append((i["user_id"], i["room_id"])) + return retval + + def get_users_who_share_private_rooms(self): + return self.get_success( + self.store._simple_select_list( + "users_who_share_private_rooms", + None, + ["user_id", "other_user_id", "room_id"], + ) + ) + + def _add_background_updates(self): + """ + Add the background updates we need to run. + """ + # Ugh, have to reset this flag + self.store._all_done = False + + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_createtables", + "progress_json": "{}", + }, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_process_rooms", + "progress_json": "{}", + "depends_on": "populate_user_directory_createtables", + }, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_process_users", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_rooms", + }, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_user_directory_cleanup", + "progress_json": "{}", + "depends_on": "populate_user_directory_process_users", + }, + ) + ) + + def test_initial(self): + """ + The user directory's initial handler correctly updates the search tables. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + u3 = self.register_user("user3", "pass") + u3_token = self.login(u3, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token) + self.helper.join(private_room, user=u3, tok=u3_token) + + self.get_success(self.store.update_user_directory_stream_pos(None)) + self.get_success(self.store.delete_all_from_user_dir()) + + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + # Nothing updated yet + self.assertEqual(shares_private, []) + self.assertEqual(public_users, []) + + # Do the initial population of the user directory via the background update + self._add_background_updates() + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + # User 1 and User 2 are in the same public room + self.assertEqual(set(public_users), set([(u1, room), (u2, room)])) + + # User 1 and User 3 share private rooms + self.assertEqual( + self._compress_shared(shares_private), + set([(u1, u3, private_room), (u3, u1, private_room)]), + ) + + def test_initial_share_all_users(self): + """ + Search all users = True means that a user does not have to share a + private room with the searching user or be in a public room to be search + visible. + """ + self.handler.search_all_users = True + self.hs.config.user_directory_search_all_users = True + + u1 = self.register_user("user1", "pass") + self.register_user("user2", "pass") + u3 = self.register_user("user3", "pass") + + # Wipe the user dir + self.get_success(self.store.update_user_directory_stream_pos(None)) + self.get_success(self.store.delete_all_from_user_dir()) + + # Do the initial population of the user directory via the background update + self._add_background_updates() + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + # No users share rooms + self.assertEqual(public_users, []) + self.assertEqual(self._compress_shared(shares_private), set([])) + + # Despite not sharing a room, search_all_users means we get a search + # result. + s = self.get_success(self.handler.search_users(u1, u3, 10)) + self.assertEqual(len(s["results"]), 1) + + # We can find the other two users + s = self.get_success(self.handler.search_users(u1, "user", 10)) + self.assertEqual(len(s["results"]), 2) + + # Registering a user and then searching for them works. + u4 = self.register_user("user4", "pass") + s = self.get_success(self.handler.search_users(u1, u4, 10)) + self.assertEqual(len(s["results"]), 1) + + +class TestUserDirSearchDisabled(unittest.HomeserverTestCase): + user_id = "@test:test" + + servlets = [ + user_directory.register_servlets, + room.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config.update_user_directory = True + hs = self.setup_test_homeserver(config=config) + + self.config = hs.config + + return hs + + def test_disabling_room_list(self): + self.config.user_directory_search_enabled = True + + # First we create a room with another user so that user dir is non-empty + # for our user + self.helper.create_room_as(self.user_id) + u2 = self.register_user("user2", "pass") + room = self.helper.create_room_as(self.user_id) + self.helper.join(room, user=u2) + + # Assert user directory is not empty + request, channel = self.make_request( + "POST", + b"user_directory/search", + b'{"search_term":"user2"}', + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertTrue(len(channel.json_body["results"]) > 0) + + # Disable user directory and check search returns nothing + self.config.user_directory_search_enabled = False + request, channel = self.make_request( + "POST", + b"user_directory/search", + b'{"search_term":"user2"}', + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertTrue(len(channel.json_body["results"]) == 0) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index b03b37affe..cd8e086f86 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py
@@ -268,6 +268,105 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, TimeoutError) + def test_client_requires_trailing_slashes(self): + """ + If a connection is made to a client but the client rejects it due to + requiring a trailing slash. We need to retry the request with a + trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. + """ + d = self.cl.get_json( + "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, + ) + + # Send the request + self.pump() + + # there should have been a call to connectTCP + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (_host, _port, factory, _timeout, _bindAddress) = clients[0] + + # complete the connection and wire it up to a fake transport + client = factory.buildProtocol(None) + conn = StringTransport() + client.makeConnection(conn) + + # that should have made it send the request to the connection + self.assertRegex(conn.value(), b"^GET /foo/bar") + + # Clear the original request data before sending a response + conn.clear() + + # Send the HTTP response + client.dataReceived( + b"HTTP/1.1 400 Bad Request\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 59\r\n" + b"\r\n" + b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}' + ) + + # We should get another request with a trailing slash + self.assertRegex(conn.value(), b"^GET /foo/bar/") + + # Send a happy response this time + client.dataReceived( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 2\r\n" + b"\r\n" + b'{}' + ) + + # We should get a successful response + r = self.successResultOf(d) + self.assertEqual(r, {}) + + def test_client_does_not_retry_on_400_plus(self): + """ + Another test for trailing slashes but now test that we don't retry on + trailing slashes on a non-400/M_UNRECOGNIZED response. + + See test_client_requires_trailing_slashes() for context. + """ + d = self.cl.get_json( + "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, + ) + + # Send the request + self.pump() + + # there should have been a call to connectTCP + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (_host, _port, factory, _timeout, _bindAddress) = clients[0] + + # complete the connection and wire it up to a fake transport + client = factory.buildProtocol(None) + conn = StringTransport() + client.makeConnection(conn) + + # that should have made it send the request to the connection + self.assertRegex(conn.value(), b"^GET /foo/bar") + + # Clear the original request data before sending a response + conn.clear() + + # Send the HTTP response + client.dataReceived( + b"HTTP/1.1 404 Not Found\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 2\r\n" + b"\r\n" + b"{}" + ) + + # We should not get another request + self.assertEqual(conn.value(), b"") + + # We should get a 404 failure response + self.failureResultOf(d) + def test_client_sends_body(self): self.cl.post_json( "testserv:8008", "foo/bar", timeout=10000, diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 50ee6910d1..be3fed8de3 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py
@@ -63,8 +63,10 @@ class EmailPusherTests(HomeserverTestCase): config.email_smtp_port = 20 config.require_transport_security = False config.email_smtp_user = None + config.email_smtp_pass = None config.email_app_name = "Matrix" config.email_notif_from = "test@example.com" + config.email_riot_base_url = None hs = self.setup_test_homeserver(config=config, sendmail=sendmail) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 9e9fbbfe93..524af4f8d1 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py
@@ -31,10 +31,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver( "blue", federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) - hs.get_ratelimiter().send_message.return_value = (True, 0) + hs.get_ratelimiter().can_do_action.return_value = (True, 0) return hs diff --git a/tests/replication/tcp/__init__.py b/tests/replication/tcp/__init__.py new file mode 100644
index 0000000000..1453d04571 --- /dev/null +++ b/tests/replication/tcp/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/tcp/streams/__init__.py b/tests/replication/tcp/streams/__init__.py new file mode 100644
index 0000000000..1453d04571 --- /dev/null +++ b/tests/replication/tcp/streams/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py new file mode 100644
index 0000000000..38b368a972 --- /dev/null +++ b/tests/replication/tcp/streams/_base.py
@@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.replication.tcp.commands import ReplicateCommand +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from tests import unittest +from tests.server import FakeTransport + + +class BaseStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests of the replication streams""" + def prepare(self, reactor, clock, hs): + # build a replication server + server_factory = ReplicationStreamProtocolFactory(self.hs) + self.streamer = server_factory.streamer + server = server_factory.buildProtocol(None) + + # build a replication client, with a dummy handler + self.test_handler = TestReplicationClientHandler() + self.client = ClientReplicationStreamProtocol( + "client", "test", clock, self.test_handler + ) + + # wire them together + self.client.makeConnection(FakeTransport(server, reactor)) + server.makeConnection(FakeTransport(self.client, reactor)) + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump(0.1) + + def replicate_stream(self, stream, token="NOW"): + """Make the client end a REPLICATE command to set up a subscription to a stream""" + self.client.send_command(ReplicateCommand(stream, token)) + + +class TestReplicationClientHandler(object): + """Drop-in for ReplicationClientHandler which just collects RDATA rows""" + def __init__(self): + self.received_rdata_rows = [] + + def get_streams_to_replicate(self): + return {} + + def get_currently_syncing_users(self): + return [] + + def update_connection(self, connection): + pass + + def finished_connecting(self): + pass + + def on_rdata(self, stream_name, token, rows): + for r in rows: + self.received_rdata_rows.append( + (stream_name, token, r) + ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py new file mode 100644
index 0000000000..9aa9dfe82e --- /dev/null +++ b/tests/replication/tcp/streams/test_receipts.py
@@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.replication.tcp.streams import ReceiptsStreamRow + +from tests.replication.tcp.streams._base import BaseStreamTestCase + +USER_ID = "@feeling:blue" +ROOM_ID = "!room:blue" +EVENT_ID = "$event:blue" + + +class ReceiptsStreamTestCase(BaseStreamTestCase): + def test_receipt(self): + # make the client subscribe to the receipts stream + self.replicate_stream("receipts", "NOW") + + # tell the master to send a new receipt + self.get_success( + self.hs.get_datastore().insert_receipt( + ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} + ) + ) + self.replicate() + + # there should be one RDATA command + rdata_rows = self.test_handler.received_rdata_rows + self.assertEqual(1, len(rdata_rows)) + self.assertEqual(rdata_rows[0][0], "receipts") + row = rdata_rows[0][2] # type: ReceiptsStreamRow + self.assertEqual(ROOM_ID, row.room_id) + self.assertEqual("m.read", row.receipt_type) + self.assertEqual(USER_ID, row.user_id) + self.assertEqual(EVENT_ID, row.event_id) + self.assertEqual({"a": 1}, row.data) diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
index 407bf0ac4c..ef38473bd6 100644 --- a/tests/rest/client/v1/test_admin.py +++ b/tests/rest/client/v1/test_admin.py
@@ -20,14 +20,48 @@ import json from mock import Mock from synapse.api.constants import UserTypes -from synapse.rest.client.v1.admin import register_servlets +from synapse.rest.client.v1 import admin, events, login, room from tests import unittest +class VersionTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + ] + + url = '/_matrix/client/r0/admin/server_version' + + def test_version_string(self): + self.register_user("admin", "pass", admin=True) + self.admin_token = self.login("admin", "pass") + + request, channel = self.make_request("GET", self.url, + access_token=self.admin_token) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), + msg=channel.result["body"]) + self.assertEqual({'server_version', 'python_version'}, + set(channel.json_body.keys())) + + def test_inaccessible_to_non_admins(self): + self.register_user("unprivileged-user", "pass", admin=False) + user_token = self.login("unprivileged-user", "pass") + + request, channel = self.make_request("GET", self.url, + access_token=user_token) + self.render(request) + + self.assertEqual(403, int(channel.result['code']), + msg=channel.result['body']) + + class UserRegisterTestCase(unittest.HomeserverTestCase): - servlets = [register_servlets] + servlets = [admin.register_servlets] def make_homeserver(self, reactor, clock): @@ -319,3 +353,140 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual('Invalid user type', channel.json_body["error"]) + + +class ShutdownRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = ( + "http://example.com" + ) + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success( + self.store.user_set_consent_version(self.admin_user, "1"), + ) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + """ + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Assert one user in room + users_in_room = self.get_success( + self.store.get_users_in_room(room_id), + ) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + room_id, + body="foo", tok=self.other_user_token, expect_code=403, + ) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode('ascii'), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert there is now no longer anyone in the room + users_in_room = self.get_success( + self.store.get_users_in_room(room_id), + ) + self.assertEqual([], users_in_room) + + @unittest.DEBUG + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + """ + + self.event_creation_handler._block_events_without_consent_error = None + + room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (room_id,) + request, channel = self.make_request( + "PUT", + url.encode('ascii'), + json.dumps({"history_visibility": "world_readable"}), + access_token=self.other_user_token, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Test that the admin can still send shutdown + url = "admin/shutdown_room/" + room_id + request, channel = self.make_request( + "POST", + url.encode('ascii'), + json.dumps({"new_room_user_id": self.admin_user}), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Assert we can no longer peek into the room + self._assert_peek(room_id, expect_code=403) + + def _assert_peek(self, room_id, expect_code): + """Assert that the admin user can (or cannot) peek into the room. + """ + + url = "rooms/%s/initialSync" % (room_id,) + request, channel = self.make_request( + "GET", + url.encode('ascii'), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"], + ) + + url = "events?timeout=0&room_id=" + room_id + request, channel = self.make_request( + "GET", + url.encode('ascii'), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual( + expect_code, int(channel.result["code"]), msg=channel.result["body"], + ) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 483bebc832..36d8547275 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py
@@ -40,10 +40,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): config.auto_join_rooms = [] hs = self.setup_test_homeserver( - config=config, ratelimiter=NonCallableMock(spec_set=["send_message"]) + config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) ) self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py new file mode 100644
index 0000000000..86312f1096 --- /dev/null +++ b/tests/rest/client/v1/test_login.py
@@ -0,0 +1,163 @@ +import json + +from synapse.rest.client.v1 import admin, login + +from tests import unittest + +LOGIN_URL = b"/_matrix/client/r0/login" + + +class LoginRestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + self.hs = self.setup_test_homeserver() + self.hs.config.enable_registration = True + self.hs.config.registrations_require_3pid = [] + self.hs.config.auto_join_rooms = [] + self.hs.config.enable_registration_captcha = False + + return self.hs + + def test_POST_ratelimiting_per_address(self): + self.hs.config.rc_login_address.burst_count = 5 + self.hs.config.rc_login_address.per_second = 0.17 + + # Create different users so we're sure not to be bothered by the per-user + # ratelimiter. + for i in range(0, 6): + self.register_user("kermit" + str(i), "monkey") + + for i in range(0, 6): + params = { + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "kermit" + str(i), + }, + "password": "monkey", + } + request_data = json.dumps(params) + request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + self.render(request) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower + # than 1min. + self.assertTrue(retry_after_ms < 6000) + + self.reactor.advance(retry_after_ms / 1000.) + + params = { + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "kermit" + str(i), + }, + "password": "monkey", + } + request_data = json.dumps(params) + request, channel = self.make_request(b"POST", LOGIN_URL, params) + self.render(request) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_POST_ratelimiting_per_account(self): + self.hs.config.rc_login_account.burst_count = 5 + self.hs.config.rc_login_account.per_second = 0.17 + + self.register_user("kermit", "monkey") + + for i in range(0, 6): + params = { + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "kermit", + }, + "password": "monkey", + } + request_data = json.dumps(params) + request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + self.render(request) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower + # than 1min. + self.assertTrue(retry_after_ms < 6000) + + self.reactor.advance(retry_after_ms / 1000.) + + params = { + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "kermit", + }, + "password": "monkey", + } + request_data = json.dumps(params) + request, channel = self.make_request(b"POST", LOGIN_URL, params) + self.render(request) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_POST_ratelimiting_per_account_failed_attempts(self): + self.hs.config.rc_login_failed_attempts.burst_count = 5 + self.hs.config.rc_login_failed_attempts.per_second = 0.17 + + self.register_user("kermit", "monkey") + + for i in range(0, 6): + params = { + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "kermit", + }, + "password": "notamonkey", + } + request_data = json.dumps(params) + request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + self.render(request) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"403", channel.result) + + # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower + # than 1min. + self.assertTrue(retry_after_ms < 6000) + + self.reactor.advance(retry_after_ms / 1000.) + + params = { + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "kermit", + }, + "password": "notamonkey", + } + request_data = json.dumps(params) + request, channel = self.make_request(b"POST", LOGIN_URL, params) + self.render(request) + + self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index a824be9a62..015c144248 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py
@@ -41,10 +41,10 @@ class RoomBase(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.ratelimiter = self.hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) self.hs.get_federation_handler = Mock(return_value=Mock()) @@ -96,7 +96,7 @@ class RoomPermissionsTestCase(RoomBase): # auth as user_id now self.helper.auth_user_id = self.user_id - def test_send_message(self): + def test_can_do_action(self): msg_content = b'{"msgtype":"m.text","body":"hello"}' seq = iter(range(100)) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 0ad814c5e5..30fb77bac8 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py
@@ -42,13 +42,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.event_source = hs.get_event_sources().sources["typing"] self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 9c401bf300..05b0143c42 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py
@@ -18,136 +18,11 @@ import time import attr -from twisted.internet import defer - from synapse.api.constants import Membership -from tests import unittest from tests.server import make_request, render -class RestTestCase(unittest.TestCase): - """Contains extra helper functions to quickly and clearly perform a given - REST action, which isn't the focus of the test. - - This subclass assumes there are mock_resource and auth_user_id attributes. - """ - - def __init__(self, *args, **kwargs): - super(RestTestCase, self).__init__(*args, **kwargs) - self.mock_resource = None - self.auth_user_id = None - - @defer.inlineCallbacks - def create_room_as(self, room_creator, is_public=True, tok=None): - temp_id = self.auth_user_id - self.auth_user_id = room_creator - path = "/createRoom" - content = "{}" - if not is_public: - content = '{"visibility":"private"}' - if tok: - path = path + "?access_token=%s" % tok - (code, response) = yield self.mock_resource.trigger("POST", path, content) - self.assertEquals(200, code, msg=str(response)) - self.auth_user_id = temp_id - defer.returnValue(response["room_id"]) - - @defer.inlineCallbacks - def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=src, - targ=targ, - tok=tok, - membership=Membership.INVITE, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def join(self, room=None, user=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.JOIN, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def leave(self, room=None, user=None, expect_code=200, tok=None): - yield self.change_membership( - room=room, - src=user, - targ=user, - tok=tok, - membership=Membership.LEAVE, - expect_code=expect_code, - ) - - @defer.inlineCallbacks - def change_membership(self, room, src, targ, membership, tok=None, expect_code=200): - temp_id = self.auth_user_id - self.auth_user_id = src - - path = "/rooms/%s/state/m.room.member/%s" % (room, targ) - if tok: - path = path + "?access_token=%s" % tok - - data = {"membership": membership} - - (code, response) = yield self.mock_resource.trigger( - "PUT", path, json.dumps(data) - ) - self.assertEquals( - expect_code, - code, - msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response), - ) - - self.auth_user_id = temp_id - - @defer.inlineCallbacks - def register(self, user_id): - (code, response) = yield self.mock_resource.trigger( - "POST", - "/register", - json.dumps( - {"user": user_id, "password": "test", "type": "m.login.password"} - ), - ) - self.assertEquals(200, code, msg=response) - defer.returnValue(response) - - @defer.inlineCallbacks - def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): - if txn_id is None: - txn_id = "m%s" % (str(time.time())) - if body is None: - body = "body_text_here" - - path = "/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) - content = '{"msgtype":"m.text","body":"%s"}' % body - if tok: - path = path + "?access_token=%s" % tok - - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(expect_code, code, msg=str(response)) - - def assert_dict(self, required, actual): - """Does a partial assert of a dict. - - Args: - required (dict): The keys and value which MUST be in 'actual'. - actual (dict): The test result. Extra keys will not be checked. - """ - for key in required: - self.assertEquals( - required[key], actual[key], msg="%s mismatch. %s" % (key, actual) - ) - - @attr.s class RestHelper(object): """Contains extra helper functions to quickly and clearly perform a given diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 906b348d3e..a45e6e5e1f 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -20,6 +20,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.registrations_require_3pid = [] self.hs.config.auto_join_rooms = [] self.hs.config.enable_registration_captcha = False + self.hs.config.allow_guest_access = True return self.hs @@ -28,7 +29,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): as_token = "i_am_an_app_service" appservice = ApplicationService( - as_token, self.hs.config.hostname, + as_token, self.hs.config.server_name, id="1234", namespaces={ "users": [{"regex": r"@as_user.*", "exclusive": True}], @@ -130,3 +131,53 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") + + def test_POST_ratelimiting_guest(self): + self.hs.config.rc_registration.burst_count = 5 + self.hs.config.rc_registration.per_second = 0.17 + + for i in range(0, 6): + url = self.url + b"?kind=guest" + request, channel = self.make_request(b"POST", url, b"{}") + self.render(request) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.) + + request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + self.render(request) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_POST_ratelimiting(self): + self.hs.config.rc_registration.burst_count = 5 + self.hs.config.rc_registration.per_second = 0.17 + + for i in range(0, 6): + params = { + "username": "kermit" + str(i), + "password": "monkey", + "device_id": "frogfone", + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) + request, channel = self.make_request(b"POST", self.url, request_data) + self.render(request) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.) + + request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + self.render(request) + + self.assertEquals(channel.result["code"], b"200", channel.result) diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py new file mode 100644
index 0000000000..af8f74eb42 --- /dev/null +++ b/tests/rest/media/v1/test_base.py
@@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.media.v1._base import get_filename_from_headers + +from tests import unittest + + +class GetFileNameFromHeadersTests(unittest.TestCase): + # input -> expected result + TEST_CASES = { + b"inline; filename=abc.txt": u"abc.txt", + b'inline; filename="azerty"': u"azerty", + b'inline; filename="aze%20rty"': u"aze%20rty", + b'inline; filename="aze\"rty"': u'aze"rty', + b'inline; filename="azer;ty"': u"azer;ty", + + b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar", + } + + def tests(self): + for hdr, expected in self.TEST_CASES.items(): + res = get_filename_from_headers( + { + b'Content-Disposition': [hdr], + }, + ) + self.assertEqual( + res, expected, + "expected output for %s to be %s but was %s" % ( + hdr, expected, res, + ) + ) diff --git a/tests/server.py b/tests/server.py
index fc1e76d146..ea26dea623 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -119,14 +119,7 @@ class FakeSite: server_version_string = b"1" site_tag = "test" - - @property - def access_logger(self): - class FakeLogger: - def info(self, *args, **kwargs): - pass - - return FakeLogger() + access_logger = logging.getLogger("synapse.access.http.fake") def make_request( @@ -137,6 +130,7 @@ def make_request( access_token=None, request=SynapseRequest, shorthand=True, + federation_auth_origin=None, ): """ Make a web request using the given method and path, feed it the @@ -150,9 +144,11 @@ def make_request( a dict. shorthand: Whether to try and be helpful and prefix the given URL with the usual REST API path, if it doesn't contain it. + federation_auth_origin (bytes|None): if set to not-None, we will add a fake + Authorization header pretenting to be the given server name. Returns: - A synapse.http.site.SynapseRequest. + Tuple[synapse.http.site.SynapseRequest, channel] """ if not isinstance(method, bytes): method = method.encode('ascii') @@ -184,6 +180,11 @@ def make_request( b"Authorization", b"Bearer " + access_token.encode('ascii') ) + if federation_auth_origin is not None: + req.requestHeaders.addRawHeader( + b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,) + ) + if content: req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") @@ -288,9 +289,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): **kwargs ) - pool.runWithConnection = runWithConnection - pool.runInteraction = runInteraction - class ThreadPool: """ Threadless thread pool. @@ -316,8 +314,12 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): return d clock.threadpool = ThreadPool() - pool.threadpool = ThreadPool() - pool.running = True + + if pool: + pool.runWithConnection = runWithConnection + pool.runInteraction = runInteraction + pool.threadpool = ThreadPool() + pool.running = True return d diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index b1551df7ca..be73e718c2 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -1,3 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2018, 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from mock import Mock from twisted.internet import defer @@ -9,13 +24,18 @@ from synapse.server_notices.resource_limits_server_notices import ( ) from tests import unittest -from tests.utils import setup_test_homeserver -class TestResourceLimitsServerNotices(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) +class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): + + def make_homeserver(self, reactor, clock): + hs_config = self.default_config("test") + hs_config.server_notices_mxid = "@server:test" + + hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) + return hs + + def prepare(self, reactor, clock, hs): self.server_notices_sender = self.hs.get_server_notices_sender() # relying on [1] is far from ideal, but the only case where @@ -50,23 +70,21 @@ class TestResourceLimitsServerNotices(unittest.TestCase): self._rlsn._store.get_tags_for_room = Mock(return_value={}) self.hs.config.admin_contact = "mailto:user@test.com" - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_flag_off(self): """Tests cases where the flags indicate nothing to do""" # test hs disabled case self.hs.config.hs_disabled = True - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() # Test when mau limiting disabled self.hs.config.hs_disabled = False self.hs.limit_usage_by_mau = False - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" @@ -78,13 +96,14 @@ class TestResourceLimitsServerNotices(unittest.TestCase): return_value=defer.succeed({"123": mock_event}) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event self._send_notice.assert_called_once() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): - """Test when user has blocked notice, but notice ought to be there (NOOP)""" + """ + Test when user has blocked notice, but notice ought to be there (NOOP) + """ self._rlsn._auth.check_auth_blocking = Mock( side_effect=ResourceLimitError(403, 'foo') ) @@ -95,52 +114,49 @@ class TestResourceLimitsServerNotices(unittest.TestCase): self._rlsn._store.get_events = Mock( return_value=defer.succeed({"123": mock_event}) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_add_blocked_notice(self): - """Test when user does not have blocked notice, but should have one""" + """ + Test when user does not have blocked notice, but should have one + """ self._rlsn._auth.check_auth_blocking = Mock( side_effect=ResourceLimitError(403, 'foo') ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check contents, but 2 calls == set blocking event self.assertTrue(self._send_notice.call_count == 2) - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): - """Test when user does not have blocked notice, nor should they (NOOP)""" - + """ + Test when user does not have blocked notice, nor should they (NOOP) + """ self._rlsn._auth.check_auth_blocking = Mock() - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - @defer.inlineCallbacks def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): - - """Test when user is not part of the MAU cohort - this should not ever + """ + Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() -class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) +class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = self.hs.get_datastore() self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_manager = self.hs.get_server_notices_manager() @@ -165,26 +181,27 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): self.hs.config.admin_contact = "mailto:user@test.com" - @defer.inlineCallbacks def test_server_notice_only_sent_once(self): self.store.get_monthly_active_count = Mock(return_value=1000) self.store.user_last_seen_monthly_active = Mock(return_value=1000) # Call the function multiple times to ensure we only send the notice once - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) - yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Now lets get the last load of messages in the service notice room and # check that there is only one server notice - room_id = yield self.server_notices_manager.get_notice_room_for_user( - self.user_id + room_id = self.get_success( + self.server_notices_manager.get_notice_room_for_user(self.user_id) ) - token = yield self.event_source.get_current_token() - events, _ = yield self.store.get_recent_events_for_room( - room_id, limit=100, end_token=token.room_key + token = self.get_success(self.event_source.get_current_token()) + events, _ = self.get_success( + self.store.get_recent_events_for_room( + room_id, limit=100, end_token=token.room_key + ) ) count = 0 diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 0dde1ab2fe..fd3361404f 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py
@@ -16,7 +16,6 @@ from twisted.internet import defer from synapse.storage import UserDirectoryStore -from synapse.storage.roommember import ProfileInfo from tests import unittest from tests.utils import setup_test_homeserver @@ -34,17 +33,11 @@ class UserDirectoryStoreTestCase(unittest.TestCase): # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. - yield self.store.add_profiles_to_user_dir( - "!room:id", - { - ALICE: ProfileInfo(None, "alice"), - BOB: ProfileInfo(None, "bob"), - BOBBY: ProfileInfo(None, "bobby"), - }, - ) - yield self.store.add_users_to_public_room("!room:id", [ALICE, BOB]) - yield self.store.add_users_who_share_room( - "!room:id", False, ((ALICE, BOB), (BOB, ALICE)) + yield self.store.update_profile_in_user_dir(ALICE, "alice", None) + yield self.store.update_profile_in_user_dir(BOB, "bob", None) + yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None) + yield self.store.add_users_in_public_rooms( + "!room:id", (ALICE, BOB) ) @defer.inlineCallbacks diff --git a/tests/test_mau.py b/tests/test_mau.py
index 04f95c942f..00be1a8c21 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py
@@ -17,7 +17,7 @@ import json -from mock import Mock, NonCallableMock +from mock import Mock from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError @@ -36,7 +36,6 @@ class TestMauLimit(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.store = self.hs.get_datastore() diff --git a/tests/unittest.py b/tests/unittest.py
index fac254ff10..27403de908 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -262,6 +262,7 @@ class HomeserverTestCase(TestCase): access_token=None, request=SynapseRequest, shorthand=True, + federation_auth_origin=None, ): """ Create a SynapseRequest at the path using the method and containing the @@ -275,15 +276,18 @@ class HomeserverTestCase(TestCase): a dict. shorthand: Whether to try and be helpful and prefix the given URL with the usual REST API path, if it doesn't contain it. + federation_auth_origin (bytes|None): if set to not-None, we will add a fake + Authorization header pretenting to be the given server name. Returns: - A synapse.http.site.SynapseRequest. + Tuple[synapse.http.site.SynapseRequest, channel] """ if isinstance(content, dict): content = json.dumps(content).encode('utf8') return make_request( - self.reactor, method, path, content, access_token, request, shorthand + self.reactor, method, path, content, access_token, request, shorthand, + federation_auth_origin, ) def render(self, request): @@ -310,6 +314,9 @@ class HomeserverTestCase(TestCase): """ kwargs = dict(kwargs) kwargs.update(self._hs_args) + if "config" not in kwargs: + config = self.default_config() + kwargs["config"] = config hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() @@ -326,12 +333,21 @@ class HomeserverTestCase(TestCase): """ self.reactor.pump([by] * 100) - def get_success(self, d): + def get_success(self, d, by=0.0): if not isinstance(d, Deferred): return d - self.pump() + self.pump(by=by) return self.successResultOf(d) + def get_failure(self, d, exc): + """ + Run a Deferred and get a Failure from it. The failure must be of the type `exc`. + """ + if not isinstance(d, Deferred): + return d + self.pump() + return self.failureResultOf(d, exc) + def register_user(self, username, password, admin=False): """ Register a user. Requires the Admin API be registered. diff --git a/tests/utils.py b/tests/utils.py
index 2dfcb70a93..615b9f8cca 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018-2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -28,8 +29,8 @@ from twisted.internet import defer, reactor from synapse.api.constants import EventTypes, RoomVersions from synapse.api.errors import CodeMessageException, cs_error -from synapse.config.server import ServerConfig -from synapse.federation.transport import server +from synapse.config.homeserver import HomeServerConfig +from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer from synapse.server import HomeServer from synapse.storage import DataStore @@ -43,30 +44,32 @@ from synapse.util.logcontext import LoggingContext from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. +# +# When running under postgres, we first create a base database with the name +# POSTGRES_BASE_DB and update it to the current schema. Then, for each test case, we +# create another unique database, using the base database as a template. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False) -POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres") +POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None) +POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None) +POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None) POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) +# the dbname we will connect to in order to create the base database. +POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres" -def setupdb(): +def setupdb(): # If we're using PostgreSQL, set up the db once if USE_POSTGRES_FOR_TESTS: - pgconfig = { - "name": "psycopg2", - "args": { - "database": POSTGRES_BASE_DB, - "user": POSTGRES_USER, - "cp_min": 1, - "cp_max": 5, - }, - } - config = Mock() - config.password_providers = [] - config.database_config = pgconfig - db_engine = create_engine(pgconfig) - db_conn = db_engine.module.connect(user=POSTGRES_USER) + # create a PostgresEngine + db_engine = create_engine({"name": "psycopg2", "args": {}}) + + # connect to postgres to create the base database. + db_conn = db_engine.module.connect( + user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, + ) db_conn.autocommit = True cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) @@ -76,7 +79,10 @@ def setupdb(): # Set up in the db db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, user=POSTGRES_USER + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, ) cur = db_conn.cursor() _get_or_create_schema_state(cur, db_engine) @@ -86,7 +92,10 @@ def setupdb(): db_conn.close() def _cleanup(): - db_conn = db_engine.module.connect(user=POSTGRES_USER) + db_conn = db_engine.module.connect( + user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, + ) db_conn.autocommit = True cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) @@ -100,13 +109,25 @@ def default_config(name): """ Create a reasonable test config. """ - config = Mock() - config.signing_key = [MockKey()] + config_dict = { + "server_name": name, + "media_store_path": "media", + "uploads_path": "uploads", + + # the test signing key is just an arbitrary ed25519 key to keep the config + # parser happy + "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg", + } + + config = HomeServerConfig() + config.parse_config_dict(config_dict) + + # TODO: move this stuff into config_dict or get rid of it config.event_cache_size = 1 config.enable_registration = True + config.enable_registration_captcha = False config.macaroon_secret_key = "not even a little secret" config.expire_access_token = False - config.server_name = name config.trusted_third_party_id_servers = [] config.room_invite_state_types = [] config.password_providers = [] @@ -139,9 +160,20 @@ def default_config(name): config.admin_contact = None config.rc_messages_per_second = 10000 config.rc_message_burst_count = 10000 + config.rc_registration.per_second = 10000 + config.rc_registration.burst_count = 10000 + config.rc_login_address.per_second = 10000 + config.rc_login_address.burst_count = 10000 + config.rc_login_account.per_second = 10000 + config.rc_login_account.burst_count = 10000 + config.rc_login_failed_attempts.per_second = 10000 + config.rc_login_failed_attempts.burst_count = 10000 config.saml2_enabled = False config.public_baseurl = None config.default_identity_server = None + config.key_refresh_interval = 24 * 60 * 60 * 1000 + config.old_signing_keys = {} + config.tls_fingerprints = [] config.use_frozen_dicts = False @@ -153,13 +185,6 @@ def default_config(name): # background, which upsets the test runner. config.update_user_directory = False - def is_threepid_reserved(threepid): - return ServerConfig.is_threepid_reserved( - config.mau_limits_reserved_threepids, threepid - ) - - config.is_threepid_reserved.side_effect = is_threepid_reserved - return config @@ -186,6 +211,9 @@ def setup_test_homeserver( Args: cleanup_func : The function used to register a cleanup routine for after the test. + + Calling this method directly is deprecated: you should instead derive from + HomeserverTestCase. """ if reactor is None: from twisted.internet import reactor @@ -203,7 +231,14 @@ def setup_test_homeserver( config.database_config = { "name": "psycopg2", - "args": {"database": test_db, "cp_min": 1, "cp_max": 5}, + "args": { + "database": test_db, + "host": POSTGRES_HOST, + "password": POSTGRES_PASSWORD, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, } else: config.database_config = { @@ -217,7 +252,10 @@ def setup_test_homeserver( # the template database we generate in setupdb() if datastore is None and isinstance(db_engine, PostgresEngine): db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, user=POSTGRES_USER + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, ) db_conn.autocommit = True cur = db_conn.cursor() @@ -240,7 +278,6 @@ def setup_test_homeserver( db_config=config.database_config, version_string="Synapse/tests", database_engine=db_engine, - room_list_handler=object(), tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, @@ -267,7 +304,10 @@ def setup_test_homeserver( # Drop the test database db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, user=POSTGRES_USER + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, ) db_conn.autocommit = True cur = db_conn.cursor() @@ -298,6 +338,8 @@ def setup_test_homeserver( cleanup_func(cleanup) hs.setup() + if homeserverToUse.__name__ == "TestHomeServer": + hs.setup_master() else: hs = homeserverToUse( name, @@ -306,7 +348,6 @@ def setup_test_homeserver( config=config, version_string="Synapse/tests", database_engine=db_engine, - room_list_handler=object(), tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, @@ -324,23 +365,27 @@ def setup_test_homeserver( fed = kargs.get("resource_for_federation", None) if fed: - server.register_servlets( - hs, - resource=fed, - authenticator=server.Authenticator(hs), - ratelimiter=FederationRateLimiter( - hs.get_clock(), - window_size=hs.config.federation_rc_window_size, - sleep_limit=hs.config.federation_rc_sleep_limit, - sleep_msec=hs.config.federation_rc_sleep_delay, - reject_limit=hs.config.federation_rc_reject_limit, - concurrent_requests=hs.config.federation_rc_concurrent, - ), - ) + register_federation_servlets(hs, fed) defer.returnValue(hs) +def register_federation_servlets(hs, resource): + federation_server.register_servlets( + hs, + resource=resource, + authenticator=federation_server.Authenticator(hs), + ratelimiter=FederationRateLimiter( + hs.get_clock(), + window_size=hs.config.federation_rc_window_size, + sleep_limit=hs.config.federation_rc_sleep_limit, + sleep_msec=hs.config.federation_rc_sleep_delay, + reject_limit=hs.config.federation_rc_reject_limit, + concurrent_requests=hs.config.federation_rc_concurrent, + ), + ) + + def get_mock_call_args(pattern_func, mock_func): """ Return the arguments the mock function was called with interpreted by the pattern functions argument list. @@ -457,6 +502,9 @@ class MockKey(object): def verify(self, message, sig): assert sig == b"\x9a\x87$" + def encode(self): + return b"<fake_encoded_key>" + class MockClock(object): now = 1000 @@ -486,7 +534,7 @@ class MockClock(object): return t def looping_call(self, function, interval): - self.loopers.append([function, interval / 1000., self.now]) + self.loopers.append([function, interval / 1000.0, self.now]) def cancel_call_later(self, timer, ignore_errs=False): if timer[2]: @@ -522,7 +570,7 @@ class MockClock(object): looped[2] = self.now def advance_time_msec(self, ms): - self.advance_time(ms / 1000.) + self.advance_time(ms / 1000.0) def time_bound_deferred(self, d, *args, **kwargs): # We don't bother timing things out for now. @@ -631,7 +679,7 @@ def create_room(hs, room_id, creator_id): "sender": creator_id, "room_id": room_id, "content": {}, - } + }, ) event, context = yield event_creation_handler.create_new_client_event(builder)