diff options
Diffstat (limited to 'tests')
39 files changed, 3318 insertions, 749 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index f65a27e5f1..379e9c4ab1 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -471,6 +471,7 @@ class AuthTestCase(unittest.TestCase): def test_reserved_threepid(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 1 + self.store.get_monthly_active_count = lambda: defer.succeed(2) threepid = {'medium': 'email', 'address': 'reserved@server.com'} unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'} self.hs.config.mau_limits_reserved_threepids = [threepid] diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 48b2d3d663..2a7044801a 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -60,7 +60,7 @@ class FilteringTestCase(unittest.TestCase): invalid_filters = [ {"boom": {}}, {"account_data": "Hello World"}, - {"event_fields": ["\\foo"]}, + {"event_fields": [r"\\foo"]}, {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}}, {"event_format": "other"}, {"room": {"not_rooms": ["#foo:pik-test"]}}, @@ -109,6 +109,16 @@ class FilteringTestCase(unittest.TestCase): "event_format": "client", "event_fields": ["type", "content", "sender"], }, + + # a single backslash should be permitted (though it is debatable whether + # it should be permitted before anything other than `.`, and what that + # actually means) + # + # (note that event_fields is implemented in + # synapse.events.utils.serialize_event, and so whether this actually works + # is tested elsewhere. We just want to check that it is allowed through the + # filter validation) + {"event_fields": [r"foo\.bar"]}, ] for filter in valid_filters: try: diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 76b5090fff..a83f567ebd 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -47,7 +47,7 @@ class FrontendProxyTests(HomeserverTestCase): self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] self.resource = ( - site.resource.children["_matrix"].children["client"].children["r0"] + site.resource.children[b"_matrix"].children[b"client"].children[b"r0"] ) request, channel = self.make_request("PUT", "presence/a/status") @@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase): self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] self.resource = ( - site.resource.children["_matrix"].children["client"].children["r0"] + site.resource.children[b"_matrix"].children[b"client"].children[b"r0"] ) request, channel = self.make_request("PUT", "presence/a/status") diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index f88d28a19d..0c23068bcf 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -67,6 +67,6 @@ class ConfigGenerationTestCase(unittest.TestCase): with open(log_config_file) as f: config = f.read() # find the 'filename' line - matches = re.findall("^\s*filename:\s*(.*)$", config, re.M) + matches = re.findall(r"^\s*filename:\s*(.*)$", config, re.M) self.assertEqual(1, len(matches)) self.assertEqual(matches[0], expected) diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py new file mode 100644 index 0000000000..f37a17d618 --- /dev/null +++ b/tests/config/test_room_directory.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import yaml + +from synapse.config.room_directory import RoomDirectoryConfig + +from tests import unittest + + +class RoomDirectoryConfigTestCase(unittest.TestCase): + def test_alias_creation_acl(self): + config = yaml.load(""" + alias_creation_rules: + - user_id: "*bob*" + alias: "*" + action: "deny" + - user_id: "*" + alias: "#unofficial_*" + action: "allow" + - user_id: "@foo*:example.com" + alias: "*" + action: "allow" + - user_id: "@gah:example.com" + alias: "#goo:example.com" + action: "allow" + """) + + rd_config = RoomDirectoryConfig() + rd_config.read_config(config) + + self.assertFalse(rd_config.is_alias_creation_allowed( + user_id="@bob:example.com", + alias="#test:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@test:example.com", + alias="#unofficial_st:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@foobar:example.com", + alias="#test:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@gah:example.com", + alias="#goo:example.com", + )) + + self.assertFalse(rd_config.is_alias_creation_allowed( + user_id="@test:example.com", + alias="#test:example.com", + )) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index ff217ca8b9..d0cc492deb 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -156,7 +156,7 @@ class SerializeEventTestCase(unittest.TestCase): room_id="!foo:bar", content={"key.with.dots": {}}, ), - ["content.key\.with\.dots"], + [r"content.key\.with\.dots"], ), {"content": {"key.with.dots": {}}}, ) @@ -172,7 +172,7 @@ class SerializeEventTestCase(unittest.TestCase): "nested.dot.key": {"leaf.key": 42, "not_me_either": 1}, }, ), - ["content.nested\.dot\.key.leaf\.key"], + [r"content.nested\.dot\.key.leaf\.key"], ), {"content": {"nested.dot.key": {"leaf.key": 42}}}, ) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index ec7355688b..8ae6556c0a 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,7 +18,9 @@ from mock import Mock from twisted.internet import defer +from synapse.config.room_directory import RoomDirectoryConfig from synapse.handlers.directory import DirectoryHandler +from synapse.rest.client.v1 import directory, room from synapse.types import RoomAlias from tests import unittest @@ -102,3 +104,49 @@ class DirectoryTestCase(unittest.TestCase): ) self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + + +class TestCreateAliasACL(unittest.HomeserverTestCase): + user_id = "@test:test" + + servlets = [directory.register_servlets, room.register_servlets] + + def prepare(self, hs, reactor, clock): + # We cheekily override the config to add custom alias creation rules + config = {} + config["alias_creation_rules"] = [ + { + "user_id": "*", + "alias": "#unofficial_*", + "action": "allow", + } + ] + + rd_config = RoomDirectoryConfig() + rd_config.read_config(config) + + self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed + + return hs + + def test_denied(self): + room_id = self.helper.create_room_as(self.user_id) + + request, channel = self.make_request( + "PUT", + b"directory/room/%23test%3Atest", + ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ) + self.render(request) + self.assertEquals(403, channel.code, channel.result) + + def test_allowed(self): + room_id = self.helper.create_room_as(self.user_id) + + request, channel = self.make_request( + "PUT", + b"directory/room/%23unofficial_test%3Atest", + ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py new file mode 100644 index 0000000000..9e08eac0a5 --- /dev/null +++ b/tests/handlers/test_e2e_room_keys.py @@ -0,0 +1,397 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import mock + +from twisted.internet import defer + +import synapse.api.errors +import synapse.handlers.e2e_room_keys +import synapse.storage +from synapse.api import errors + +from tests import unittest, utils + +# sample room_key data for use in the tests +room_keys = { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": False, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } +} + + +class E2eRoomKeysHandlerTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) + self.hs = None # type: synapse.server.HomeServer + self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield utils.setup_test_homeserver( + self.addCleanup, + handlers=None, + replication_layer=mock.Mock(), + ) + self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) + self.local_user = "@boris:" + self.hs.hostname + + @defer.inlineCallbacks + def test_get_missing_current_version_info(self): + """Check that we get a 404 if we ask for info about the current version + if there is no version. + """ + res = None + try: + yield self.handler.get_version_info(self.local_user) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_get_missing_version_info(self): + """Check that we get a 404 if we ask for info about a specific version + if it doesn't exist. + """ + res = None + try: + yield self.handler.get_version_info(self.local_user, "bogus_version") + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_create_version(self): + """Check that we can create and then retrieve versions. + """ + res = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(res, "1") + + # check we can retrieve it as the current version + res = yield self.handler.get_version_info(self.local_user) + self.assertDictEqual(res, { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + + # check we can retrieve it as a specific version + res = yield self.handler.get_version_info(self.local_user, "1") + self.assertDictEqual(res, { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + + # upload a new one... + res = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }) + self.assertEqual(res, "2") + + # check we can retrieve it as the current version + res = yield self.handler.get_version_info(self.local_user) + self.assertDictEqual(res, { + "version": "2", + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }) + + @defer.inlineCallbacks + def test_delete_missing_version(self): + """Check that we get a 404 on deleting nonexistent versions + """ + res = None + try: + yield self.handler.delete_version(self.local_user, "1") + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_delete_missing_current_version(self): + """Check that we get a 404 on deleting nonexistent current version + """ + res = None + try: + yield self.handler.delete_version(self.local_user) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_delete_version(self): + """Check that we can create and then delete versions. + """ + res = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(res, "1") + + # check we can delete it + yield self.handler.delete_version(self.local_user, "1") + + # check that it's gone + res = None + try: + yield self.handler.get_version_info(self.local_user, "1") + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_get_missing_room_keys(self): + """Check that we get a 404 on querying missing room_keys + """ + res = None + try: + yield self.handler.get_room_keys(self.local_user, "bogus_version") + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # check we also get a 404 even if the version is valid + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + res = None + try: + yield self.handler.get_room_keys(self.local_user, version) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # TODO: test the locking semantics when uploading room_keys, + # although this is probably best done in sytest + + @defer.inlineCallbacks + def test_upload_room_keys_no_versions(self): + """Check that we get a 404 on uploading keys when no versions are defined + """ + res = None + try: + yield self.handler.upload_room_keys(self.local_user, "no_version", room_keys) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_upload_room_keys_bogus_version(self): + """Check that we get a 404 on uploading keys when an nonexistent version + is specified + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + res = None + try: + yield self.handler.upload_room_keys( + self.local_user, "bogus_version", room_keys + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_upload_room_keys_wrong_version(self): + """Check that we get a 403 on uploading keys for an old version + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }) + self.assertEqual(version, "2") + + res = None + try: + yield self.handler.upload_room_keys(self.local_user, "1", room_keys) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 403) + + @defer.inlineCallbacks + def test_upload_room_keys_insert(self): + """Check that we can insert and retrieve keys for a session + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertDictEqual(res, room_keys) + + # check getting room_keys for a given room + res = yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org" + ) + self.assertDictEqual(res, room_keys) + + # check getting room_keys for a given session_id + res = yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + self.assertDictEqual(res, room_keys) + + @defer.inlineCallbacks + def test_upload_room_keys_merge(self): + """Check that we can upload a new room_key for an existing session and + have it correctly merged""" + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + + new_room_keys = copy.deepcopy(room_keys) + new_room_key = new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33'] + + # test that increasing the message_index doesn't replace the existing session + new_room_key['first_message_index'] = 2 + new_room_key['session_data'] = 'new' + yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertEqual( + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + "SSBBTSBBIEZJU0gK" + ) + + # test that marking the session as verified however /does/ replace it + new_room_key['is_verified'] = True + yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertEqual( + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + "new" + ) + + # test that a session with a higher forwarded_count doesn't replace one + # with a lower forwarding count + new_room_key['forwarded_count'] = 2 + new_room_key['session_data'] = 'other' + yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertEqual( + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + "new" + ) + + # TODO: check edge cases as well as the common variations here + + @defer.inlineCallbacks + def test_delete_room_keys(self): + """Check that we can insert and delete keys for a session + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + # check for bulk-delete + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield self.handler.delete_room_keys(self.local_user, version) + res = None + try: + yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # check for bulk-delete per room + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield self.handler.delete_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + ) + res = None + try: + yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # check for bulk-delete per session + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield self.handler.delete_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + res = None + try: + yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 7b4ade3dfb..3e9a190727 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.api.errors import ResourceLimitError from synapse.handlers.register import RegistrationHandler -from synapse.types import UserID, create_requester +from synapse.types import RoomAlias, UserID, create_requester from tests.utils import setup_test_homeserver @@ -41,30 +41,27 @@ class RegistrationTestCase(unittest.TestCase): self.mock_captcha_client = Mock() self.hs = yield setup_test_homeserver( self.addCleanup, - handlers=None, - http_client=None, expire_access_token=True, - profile_handler=Mock(), ) self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret') ) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) - self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.store = self.hs.get_datastore() self.hs.config.max_mau_value = 50 self.lots_of_users = 100 self.small_number_of_users = 1 + self.requester = create_requester("@requester:test") + @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): - local_part = "someone" - display_name = "someone" - user_id = "@someone:test" - requester = create_requester("@as:test") + frank = UserID.from_string("@frank:test") + user_id = frank.to_string() + requester = create_requester(user_id) result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name + requester, frank.localpart, "Frankie" ) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -78,12 +75,11 @@ class RegistrationTestCase(unittest.TestCase): token="jkv;g498752-43gj['eamb!-5", password_hash=None, ) - local_part = "frank" - display_name = "Frank" - user_id = "@frank:test" - requester = create_requester("@as:test") + local_part = frank.localpart + user_id = frank.to_string() + requester = create_requester(user_id) result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name + requester, local_part, None ) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -92,7 +88,7 @@ class RegistrationTestCase(unittest.TestCase): def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.handler.get_or_create_user("requester", 'a', "display_name") + yield self.handler.get_or_create_user(self.requester, 'a', "display_name") @defer.inlineCallbacks def test_get_or_create_user_mau_not_blocked(self): @@ -101,7 +97,7 @@ class RegistrationTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - yield self.handler.get_or_create_user("@user:server", 'c', "User") + yield self.handler.get_or_create_user(self.requester, 'c', "User") @defer.inlineCallbacks def test_get_or_create_user_mau_blocked(self): @@ -110,13 +106,13 @@ class RegistrationTestCase(unittest.TestCase): return_value=defer.succeed(self.lots_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user("requester", 'b', "display_name") + yield self.handler.get_or_create_user(self.requester, 'b', "display_name") self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user("requester", 'b', "display_name") + yield self.handler.get_or_create_user(self.requester, 'b', "display_name") @defer.inlineCallbacks def test_register_mau_blocked(self): @@ -147,3 +143,44 @@ class RegistrationTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError): yield self.handler.register_saml2(localpart="local_part") + + @defer.inlineCallbacks + def test_auto_create_auto_join_rooms(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + res = yield self.handler.register(localpart='jeff') + rooms = yield self.store.get_rooms_for_user(res[0]) + + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = yield directory_handler.get_association(room_alias) + + self.assertTrue(room_id['room_id'] in rooms) + self.assertEqual(len(rooms), 1) + + @defer.inlineCallbacks + def test_auto_create_auto_join_rooms_with_no_rooms(self): + self.hs.config.auto_join_rooms = [] + frank = UserID.from_string("@frank:test") + res = yield self.handler.register(frank.localpart) + self.assertEqual(res[0], frank.to_string()) + rooms = yield self.store.get_rooms_for_user(res[0]) + self.assertEqual(len(rooms), 0) + + @defer.inlineCallbacks + def test_auto_create_auto_join_where_room_is_another_domain(self): + self.hs.config.auto_join_rooms = ["#room:another"] + frank = UserID.from_string("@frank:test") + res = yield self.handler.register(frank.localpart) + self.assertEqual(res[0], frank.to_string()) + rooms = yield self.store.get_rooms_for_user(res[0]) + self.assertEqual(len(rooms), 0) + + @defer.inlineCallbacks + def test_auto_create_auto_join_where_auto_create_is_false(self): + self.hs.config.autocreate_auto_join_rooms = False + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + res = yield self.handler.register(localpart='jeff') + rooms = yield self.store.get_rooms_for_user(res[0]) + self.assertEqual(len(rooms), 0) diff --git a/tests/handlers/test_roomlist.py b/tests/handlers/test_roomlist.py new file mode 100644 index 0000000000..61eebb6985 --- /dev/null +++ b/tests/handlers/test_roomlist.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.handlers.room_list import RoomListNextBatch + +import tests.unittest +import tests.utils + + +class RoomListTestCase(tests.unittest.TestCase): + """ Tests RoomList's RoomListNextBatch. """ + + def setUp(self): + pass + + def test_check_read_batch_tokens(self): + batch_token = RoomListNextBatch( + stream_ordering="abcdef", + public_room_stream_id="123", + current_limit=20, + direction_is_forward=True, + ).to_token() + next_batch = RoomListNextBatch.from_token(batch_token) + self.assertEquals(next_batch.stream_ordering, "abcdef") + self.assertEquals(next_batch.public_room_stream_id, "123") + self.assertEquals(next_batch.current_limit, 20) + self.assertEquals(next_batch.direction_is_forward, True) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index c2d951b45f..36e136cded 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -43,9 +43,7 @@ def _expect_edu_transaction(edu_type, content, origin="test"): def _make_edu_transaction_json(edu_type, content): - return json.dumps(_expect_edu_transaction(edu_type, content)).encode( - 'utf8' - ) + return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8') class TypingNotificationsTestCase(unittest.TestCase): diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py new file mode 100644 index 0000000000..f3cb1423f0 --- /dev/null +++ b/tests/http/test_fedclient.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet.defer import TimeoutError +from twisted.internet.error import ConnectingCancelledError, DNSLookupError +from twisted.web.client import ResponseNeverReceived +from twisted.web.http import HTTPChannel + +from synapse.http.matrixfederationclient import ( + MatrixFederationHttpClient, + MatrixFederationRequest, +) + +from tests.server import FakeTransport +from tests.unittest import HomeserverTestCase + + +class FederationClientTests(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + hs = self.setup_test_homeserver(reactor=reactor, clock=clock) + hs.tls_client_options_factory = None + return hs + + def prepare(self, reactor, clock, homeserver): + + self.cl = MatrixFederationHttpClient(self.hs) + self.reactor.lookups["testserv"] = "1.2.3.4" + + def test_dns_error(self): + """ + If the DNS raising returns an error, it will bubble up. + """ + d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + self.pump() + + f = self.failureResultOf(d) + self.assertIsInstance(f.value, DNSLookupError) + + def test_client_never_connect(self): + """ + If the HTTP request is not connected and is timed out, it'll give a + ConnectingCancelledError or TimeoutError. + """ + d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + + self.pump() + + # Nothing happened yet + self.assertFalse(d.called) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][1], 8008) + + # Deferred is still without a result + self.assertFalse(d.called) + + # Push by enough to time it out + self.reactor.advance(10.5) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, (ConnectingCancelledError, TimeoutError)) + + def test_client_connect_no_response(self): + """ + If the HTTP request is connected, but gets no response before being + timed out, it'll give a ResponseNeverReceived. + """ + d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + + self.pump() + + # Nothing happened yet + self.assertFalse(d.called) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][1], 8008) + + conn = Mock() + client = clients[0][2].buildProtocol(None) + client.makeConnection(conn) + + # Deferred is still without a result + self.assertFalse(d.called) + + # Push by enough to time it out + self.reactor.advance(10.5) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, ResponseNeverReceived) + + def test_client_gets_headers(self): + """ + Once the client gets the headers, _request returns successfully. + """ + request = MatrixFederationRequest( + method="GET", + destination="testserv:8008", + path="foo/bar", + ) + d = self.cl._send_request(request, timeout=10000) + + self.pump() + + conn = Mock() + clients = self.reactor.tcpClients + client = clients[0][2].buildProtocol(None) + client.makeConnection(conn) + + # Deferred does not have a result + self.assertFalse(d.called) + + # Send it the HTTP response + client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n") + + # We should get a successful response + r = self.successResultOf(d) + self.assertEqual(r.code, 200) + + def test_client_headers_no_body(self): + """ + If the HTTP request is connected, but gets no response before being + timed out, it'll give a ResponseNeverReceived. + """ + d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + + self.pump() + + conn = Mock() + clients = self.reactor.tcpClients + client = clients[0][2].buildProtocol(None) + client.makeConnection(conn) + + # Deferred does not have a result + self.assertFalse(d.called) + + # Send it the HTTP response + client.dataReceived( + (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + b"Server: Fake\r\n\r\n") + ) + + # Push by enough to time it out + self.reactor.advance(10.5) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, TimeoutError) + + def test_client_sends_body(self): + self.cl.post_json( + "testserv:8008", "foo/bar", timeout=10000, + data={"a": "b"} + ) + + self.pump() + + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + client = clients[0][2].buildProtocol(None) + server = HTTPChannel() + + client.makeConnection(FakeTransport(server, self.reactor)) + server.makeConnection(FakeTransport(client, self.reactor)) + + self.pump(0.1) + + self.assertEqual(len(server.requests), 1) + request = server.requests[0] + content = request.content.read() + self.assertEqual(content, b'{"a":"b"}') diff --git a/tests/push/__init__.py b/tests/push/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/push/__init__.py diff --git a/tests/push/test_email.py b/tests/push/test_email.py new file mode 100644 index 0000000000..50ee6910d1 --- /dev/null +++ b/tests/push/test_email.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pkg_resources + +from twisted.internet.defer import Deferred + +from synapse.rest.client.v1 import admin, login, room + +from tests.unittest import HomeserverTestCase + +try: + from synapse.push.mailer import load_jinja2_templates +except Exception: + load_jinja2_templates = None + + +class EmailPusherTests(HomeserverTestCase): + + skip = "No Jinja installed" if not load_jinja2_templates else None + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def make_homeserver(self, reactor, clock): + + # List[Tuple[Deferred, args, kwargs]] + self.email_attempts = [] + + def sendmail(*args, **kwargs): + d = Deferred() + self.email_attempts.append((d, args, kwargs)) + return d + + config = self.default_config() + config.email_enable_notifs = True + config.start_pushers = True + + config.email_template_dir = os.path.abspath( + pkg_resources.resource_filename('synapse', 'res/templates') + ) + config.email_notif_template_html = "notif_mail.html" + config.email_notif_template_text = "notif_mail.txt" + config.email_smtp_host = "127.0.0.1" + config.email_smtp_port = 20 + config.require_transport_security = False + config.email_smtp_user = None + config.email_app_name = "Matrix" + config.email_notif_from = "test@example.com" + + hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + + return hs + + def test_sends_email(self): + + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastore().get_user_by_access_token(access_token) + ) + token_id = user_tuple["token_id"] + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name="a@example.com", + pushkey="a@example.com", + lang=None, + data={}, + ) + ) + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Invite the other person + self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # The other user sends some messages + self.helper.send(room, body="Hi!", tok=other_access_token) + self.helper.send(room, body="There!", tok=other_access_token) + + # Get the stream ordering before it gets sent + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + last_stream_ordering = pushers[0]["last_stream_ordering"] + + # Advance time a bit, so the pusher will register something has happened + self.pump(100) + + # It hasn't succeeded yet, so the stream ordering shouldn't have moved + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) + + # One email was attempted to be sent + self.assertEqual(len(self.email_attempts), 1) + + # Make the email succeed + self.email_attempts[0][0].callback(True) + self.pump() + + # One email was attempted to be sent + self.assertEqual(len(self.email_attempts), 1) + + # The stream ordering has increased + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + ) + self.assertEqual(len(pushers), 1) + self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 089cecfbee..9e9fbbfe93 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -15,8 +15,6 @@ from mock import Mock, NonCallableMock -import attr - from synapse.replication.tcp.client import ( ReplicationClientFactory, ReplicationClientHandler, @@ -24,6 +22,7 @@ from synapse.replication.tcp.client import ( from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from tests import unittest +from tests.server import FakeTransport class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): @@ -56,36 +55,8 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): server = server_factory.buildProtocol(None) client = client_factory.buildProtocol(None) - @attr.s - class FakeTransport(object): - - other = attr.ib() - disconnecting = False - buffer = attr.ib(default=b'') - - def registerProducer(self, producer, streaming): - - self.producer = producer - - def _produce(): - self.producer.resumeProducing() - reactor.callLater(0.1, _produce) - - reactor.callLater(0.0, _produce) - - def write(self, byt): - self.buffer = self.buffer + byt - - if getattr(self.other, "transport") is not None: - self.other.dataReceived(self.buffer) - self.buffer = b"" - - def writeSequence(self, seq): - for x in seq: - self.write(x) - - client.makeConnection(FakeTransport(server)) - server.makeConnection(FakeTransport(client)) + client.makeConnection(FakeTransport(server, reactor)) + server.makeConnection(FakeTransport(client, reactor)) def replicate(self): """Tell the master side of replication that something has happened, and then diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index db44d33c68..41be5d5a1a 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from canonicaljson import encode_canonical_json + from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext from synapse.replication.slave.storage.events import SlavedEventStore @@ -26,7 +28,9 @@ ROOM_ID = "!room:blue" def dict_equals(self, other): - return self.__dict__ == other.__dict__ + me = encode_canonical_json(self._event_dict) + them = encode_canonical_json(other._event_dict) + return me == them def patch__eq__(cls): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 9fe0760496..a824be9a62 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -22,39 +22,24 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer -import synapse.rest.client.v1.room from synapse.api.constants import Membership -from synapse.http.server import JsonResource -from synapse.types import UserID -from synapse.util import Clock +from synapse.rest.client.v1 import admin, login, room from tests import unittest -from tests.server import ( - ThreadedMemoryReactorClock, - make_request, - render, - setup_test_homeserver, -) - -from .utils import RestHelper PATH_PREFIX = b"/_matrix/client/api/v1" -class RoomBase(unittest.TestCase): +class RoomBase(unittest.HomeserverTestCase): rmcreator_id = None - def setUp(self): + servlets = [room.register_servlets, room.register_deprecated_servlets] - self.clock = ThreadedMemoryReactorClock() - self.hs_clock = Clock(self.clock) + def make_homeserver(self, reactor, clock): - self.hs = setup_test_homeserver( - self.addCleanup, + self.hs = self.setup_test_homeserver( "red", http_client=None, - clock=self.hs_clock, - reactor=self.clock, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) @@ -63,42 +48,21 @@ class RoomBase(unittest.TestCase): self.hs.get_federation_handler = Mock(return_value=Mock()) - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } - - def get_user_by_req(request, allow_guest=False, rights="access"): - return synapse.types.create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None - ) - - self.hs.get_auth().get_user_by_req = get_user_by_req - self.hs.get_auth().get_user_by_access_token = get_user_by_access_token - self.hs.get_auth().get_access_token_from_request = Mock(return_value=b"1234") - def _insert_client_ip(*args, **kwargs): return defer.succeed(None) self.hs.get_datastore().insert_client_ip = _insert_client_ip - self.resource = JsonResource(self.hs) - synapse.rest.client.v1.room.register_servlets(self.hs, self.resource) - synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource) - self.helper = RestHelper(self.hs, self.resource, self.user_id) + return self.hs class RoomPermissionsTestCase(RoomBase): """ Tests room permissions. """ - user_id = b"@sid1:red" - rmcreator_id = b"@notme:red" - - def setUp(self): + user_id = "@sid1:red" + rmcreator_id = "@notme:red" - super(RoomPermissionsTestCase, self).setUp() + def prepare(self, reactor, clock, hs): self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id @@ -114,22 +78,20 @@ class RoomPermissionsTestCase(RoomBase): self.created_rmid_msg_path = ( "rooms/%s/send/m.room.message/a1" % (self.created_rmid) ).encode('ascii') - request, channel = make_request( - b"PUT", - self.created_rmid_msg_path, - b'{"msgtype":"m.text","body":"test msg"}', + request, channel = self.make_request( + "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - render(request, self.resource, self.clock) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.render(request) + self.assertEquals(200, channel.code, channel.result) # set topic for public room - request, channel = make_request( - b"PUT", + request, channel = self.make_request( + "PUT", ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'), b'{"topic":"Public Room Topic"}', ) - render(request, self.resource, self.clock) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.render(request) + self.assertEquals(200, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -140,128 +102,128 @@ class RoomPermissionsTestCase(RoomBase): seq = iter(range(100)) def send_msg_path(): - return b"/rooms/%s/send/m.room.message/mid%s" % ( + return "/rooms/%s/send/m.room.message/mid%s" % ( self.created_rmid, - str(next(seq)).encode('ascii'), + str(next(seq)), ) # send message in uncreated room, expect 403 - request, channel = make_request( - b"PUT", - b"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), + request, channel = self.make_request( + "PUT", + "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_topic_perms(self): topic_content = b'{"topic":"My Topic Name"}' - topic_path = b"/rooms/%s/state/m.room.topic" % self.created_rmid + topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid # set/get topic in uncreated room, expect 403 - request, channel = make_request( - b"PUT", b"/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content + request, channel = self.make_request( + "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - request, channel = make_request( - b"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + request, channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assert_dict(json.loads(topic_content), channel.json_body) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 - request, channel = make_request( - b"GET", b"/rooms/%s/state/m.room.topic" % self.created_public_rmid + request, channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 - request, channel = make_request( - b"PUT", - b"/rooms/%s/state/m.room.topic" % self.created_public_rmid, + request, channel = self.make_request( + "PUT", + "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def _test_get_membership(self, room=None, members=[], expect_code=None): for member in members: - path = b"/rooms/%s/state/m.room.member/%s" % (room, member) - request, channel = make_request(b"GET", path) - render(request, self.resource, self.clock) - self.assertEquals(expect_code, int(channel.result["code"])) + path = "/rooms/%s/state/m.room.member/%s" % (room, member) + request, channel = self.make_request("GET", path) + self.render(request) + self.assertEquals(expect_code, channel.code) def test_membership_basic_room_perms(self): # === room does not exist === @@ -428,217 +390,211 @@ class RoomPermissionsTestCase(RoomBase): class RoomsMemberListTestCase(RoomBase): """ Tests /rooms/$room_id/members/list REST events.""" - user_id = b"@sid1:red" + user_id = "@sid1:red" def test_get_member_list(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self): - request, channel = make_request(b"GET", b"/rooms/roomdoesnotexist/members") - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self): - room_id = self.helper.create_room_as(b"@some_other_guy:red") - request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + room_id = self.helper.create_room_as("@some_other_guy:red") + request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self): - room_creator = b"@some_other_guy:red" + room_creator = "@some_other_guy:red" room_id = self.helper.create_room_as(room_creator) - room_path = b"/rooms/%s/members" % room_id + room_path = "/rooms/%s/members" % room_id self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. - request, channel = make_request(b"GET", room_path) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", room_path) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined - request, channel = make_request(b"GET", room_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", room_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left - request, channel = make_request(b"GET", room_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", room_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) class RoomsCreateTestCase(RoomBase): """ Tests /rooms and /rooms/$room_id REST events. """ - user_id = b"@sid1:red" + user_id = "@sid1:red" def test_post_room_no_keys(self): # POST with no config keys, expect new room id - request, channel = make_request(b"POST", b"/createRoom", b"{}") + request, channel = self.make_request("POST", "/createRoom", "{}") - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), channel.result) + self.render(request) + self.assertEquals(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id - request, channel = make_request( - b"POST", b"/createRoom", b'{"visibility":"private"}' + request, channel = self.make_request( + "POST", "/createRoom", b'{"visibility":"private"}' ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self): # POST with custom config keys, expect new room id - request, channel = make_request(b"POST", b"/createRoom", b'{"custom":"stuff"}') - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + request, channel = self.make_request( + "POST", "/createRoom", b'{"custom":"stuff"}' + ) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self): # POST with custom + known config keys, expect new room id - request, channel = make_request( - b"POST", b"/createRoom", b'{"visibility":"private","custom":"things"}' + request, channel = self.make_request( + "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 - request, channel = make_request(b"POST", b"/createRoom", b'{"visibili') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"])) + request, channel = self.make_request("POST", "/createRoom", b'{"visibili') + self.render(request) + self.assertEquals(400, channel.code) - request, channel = make_request(b"POST", b"/createRoom", b'["hello"]') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"])) + request, channel = self.make_request("POST", "/createRoom", b'["hello"]') + self.render(request) + self.assertEquals(400, channel.code) class RoomTopicTestCase(RoomBase): """ Tests /rooms/$room_id/topic REST events. """ - user_id = b"@sid1:red" - - def setUp(self): - - super(RoomTopicTestCase, self).setUp() + user_id = "@sid1:red" + def prepare(self, reactor, clock, hs): # create the room self.room_id = self.helper.create_room_as(self.user_id) - self.path = b"/rooms/%s/state/m.room.topic" % (self.room_id,) + self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) def test_invalid_puts(self): # missing keys or invalid json - request, channel = make_request(b"PUT", self.path, '{}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '{}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, '{"_name":"bob"}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, '{"nao') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '{"nao') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request( - b"PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = self.make_request( + "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, 'text only') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, 'text only') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, '') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' - request, channel = make_request(b"PUT", self.path, content) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, content) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_topic(self): # nothing should be there - request, channel = make_request(b"GET", self.path) - render(request, self.resource, self.clock) - self.assertEquals(404, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", self.path) + self.render(request) + self.assertEquals(404, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' - request, channel = make_request(b"PUT", self.path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = make_request(b"GET", self.path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", self.path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' - request, channel = make_request(b"PUT", self.path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = make_request(b"GET", self.path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", self.path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) class RoomMemberStateTestCase(RoomBase): """ Tests /rooms/$room_id/members/$user_id/state REST events. """ - user_id = b"@sid1:red" - - def setUp(self): + user_id = "@sid1:red" - super(RoomMemberStateTestCase, self).setUp() + def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) - def tearDown(self): - pass - def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - request, channel = make_request(b"PUT", path, '{}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '{}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"_name":"bob"}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '{"_name":"bo"}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"nao') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '{"nao') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request( - b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]' + request, channel = self.make_request( + "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' ) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, 'text only') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, 'text only') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -646,9 +602,9 @@ class RoomMemberStateTestCase(RoomBase): Membership.JOIN, Membership.LEAVE, ) - request, channel = make_request(b"PUT", path, content.encode('ascii')) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content.encode('ascii')) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( @@ -658,13 +614,13 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - request, channel = make_request(b"PUT", path, content.encode('ascii')) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content.encode('ascii')) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"GET", path, None) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", path, None) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEquals(expected_response, channel.json_body) @@ -678,13 +634,13 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"GET", path, None) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", path, None) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self): @@ -699,13 +655,13 @@ class RoomMemberStateTestCase(RoomBase): Membership.INVITE, "Join us!", ) - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"GET", path, None) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", path, None) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) @@ -714,60 +670,58 @@ class RoomMessagesTestCase(RoomBase): user_id = "@sid1:red" - def setUp(self): - super(RoomMessagesTestCase, self).setUp() - + def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - request, channel = make_request(b"PUT", path, '{}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'{}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"_name":"bob"}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'{"_name":"bo"}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"nao') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'{"nao') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request( - b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = self.make_request( + "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' ) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, 'text only') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'text only') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) - content = '{"body":"test","msgtype":{"type":"a"}}' - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + content = b'{"body":"test","msgtype":{"type":"a"}}' + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) # custom message types - content = '{"body":"test","msgtype":"test.custom.text"}' - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + content = b'{"body":"test","msgtype":"test.custom.text"}' + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) - content = '{"body":"test2","msgtype":"m.text"}' - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + content = b'{"body":"test2","msgtype":"m.text"}' + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) class RoomInitialSyncTestCase(RoomBase): @@ -775,16 +729,16 @@ class RoomInitialSyncTestCase(RoomBase): user_id = "@sid1:red" - def setUp(self): - super(RoomInitialSyncTestCase, self).setUp() - + def prepare(self, reactor, clock, hs): # create the room self.room_id = self.helper.create_room_as(self.user_id) def test_initial_sync(self): - request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + request, channel = self.make_request( + "GET", "/rooms/%s/initialSync" % self.room_id + ) + self.render(request) + self.assertEquals(200, channel.code) self.assertEquals(self.room_id, channel.json_body["room_id"]) self.assertEquals("join", channel.json_body["membership"]) @@ -819,17 +773,16 @@ class RoomMessageListTestCase(RoomBase): user_id = "@sid1:red" - def setUp(self): - super(RoomMessageListTestCase, self).setUp() + def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) def test_topo_token_is_accepted(self): token = "t1-0_0_0_0_0_0_0_0_0" - request, channel = make_request( - b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + request, channel = self.make_request( + "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) self.assertEquals(token, channel.json_body['start']) self.assertTrue("chunk" in channel.json_body) @@ -837,12 +790,116 @@ class RoomMessageListTestCase(RoomBase): def test_stream_token_is_accepted_for_fwd_pagianation(self): token = "s0_0_0_0_0_0_0_0_0" - request, channel = make_request( - b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + request, channel = self.make_request( + "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) self.assertEquals(token, channel.json_body['start']) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) + + +class RoomSearchTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def prepare(self, reactor, clock, hs): + + # Register the user who does the searching + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + # Register the user who sends the message + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + # Create a room + self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) + + # Invite the other person + self.helper.invite( + room=self.room, + src=self.user_id, + tok=self.access_token, + targ=self.other_user_id, + ) + + # The other user joins + self.helper.join( + room=self.room, user=self.other_user_id, tok=self.other_access_token + ) + + def test_finds_message(self): + """ + The search functionality will search for content in messages if asked to + do so. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + request, channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": {"keys": ["content.body"], "search_term": "Hi"} + } + }, + ) + self.render(request) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # No context was requested, so we should get none. + self.assertEqual(results["results"][0]["context"], {}) + + def test_include_context(self): + """ + When event_context includes include_profile, profile information will be + included in the search response. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + request, channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": { + "keys": ["content.body"], + "search_term": "Hi", + "event_context": {"include_profile": True}, + } + } + }, + ) + self.render(request) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # We should get context info, like the two users, and the display names. + context = results["results"][0]["context"] + self.assertEqual(len(context["profile_info"].keys()), 2) + self.assertEqual( + context["profile_info"][self.other_user_id]["displayname"], "otheruser" + ) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 560b1fba96..4c30c5f258 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -62,12 +62,6 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( set( - [ - "next_batch", - "rooms", - "account_data", - "to_device", - "device_lists", - ] + ["next_batch", "rooms", "account_data", "to_device", "device_lists"] ).issubset(set(channel.json_body.keys())) ) diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/scripts/__init__.py diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py new file mode 100644 index 0000000000..6f56893f5e --- /dev/null +++ b/tests/scripts/test_new_matrix_user.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from synapse._scripts.register_new_matrix_user import request_registration + +from tests.unittest import TestCase + + +class RegisterTestCase(TestCase): + def test_success(self): + """ + The script will fetch a nonce, and then generate a MAC with it, and then + post that MAC. + """ + + def get(url, verify=None): + r = Mock() + r.status_code = 200 + r.json = lambda: {"nonce": "a"} + return r + + def post(url, json=None, verify=None): + # Make sure we are sent the correct info + self.assertEqual(json["username"], "user") + self.assertEqual(json["password"], "pass") + self.assertEqual(json["nonce"], "a") + # We want a 40-char hex MAC + self.assertEqual(len(json["mac"]), 40) + + r = Mock() + r.status_code = 200 + return r + + requests = Mock() + requests.get = get + requests.post = post + + # The fake stdout will be written here + out = [] + err_code = [] + + request_registration( + "user", + "pass", + "matrix.org", + "shared", + admin=False, + requests=requests, + _print=out.append, + exit=err_code.append, + ) + + # We should get the success message making sure everything is OK. + self.assertIn("Success!", out) + + # sys.exit shouldn't have been called. + self.assertEqual(err_code, []) + + def test_failure_nonce(self): + """ + If the script fails to fetch a nonce, it throws an error and quits. + """ + + def get(url, verify=None): + r = Mock() + r.status_code = 404 + r.reason = "Not Found" + r.json = lambda: {"not": "error"} + return r + + requests = Mock() + requests.get = get + + # The fake stdout will be written here + out = [] + err_code = [] + + request_registration( + "user", + "pass", + "matrix.org", + "shared", + admin=False, + requests=requests, + _print=out.append, + exit=err_code.append, + ) + + # Exit was called + self.assertEqual(err_code, [1]) + + # We got an error message + self.assertIn("ERROR! Received 404 Not Found", out) + self.assertNotIn("Success!", out) + + def test_failure_post(self): + """ + The script will fetch a nonce, and then if the final POST fails, will + report an error and quit. + """ + + def get(url, verify=None): + r = Mock() + r.status_code = 200 + r.json = lambda: {"nonce": "a"} + return r + + def post(url, json=None, verify=None): + # Make sure we are sent the correct info + self.assertEqual(json["username"], "user") + self.assertEqual(json["password"], "pass") + self.assertEqual(json["nonce"], "a") + # We want a 40-char hex MAC + self.assertEqual(len(json["mac"]), 40) + + r = Mock() + # Then 500 because we're jerks + r.status_code = 500 + r.reason = "Broken" + return r + + requests = Mock() + requests.get = get + requests.post = post + + # The fake stdout will be written here + out = [] + err_code = [] + + request_registration( + "user", + "pass", + "matrix.org", + "shared", + admin=False, + requests=requests, + _print=out.append, + exit=err_code.append, + ) + + # Exit was called + self.assertEqual(err_code, [1]) + + # We got an error message + self.assertIn("ERROR! Received 500 Broken", out) + self.assertNotIn("Success!", out) diff --git a/tests/server.py b/tests/server.py index 615bba1b59..819c854448 100644 --- a/tests/server.py +++ b/tests/server.py @@ -4,9 +4,14 @@ from io import BytesIO from six import text_type import attr +from zope.interface import implementer -from twisted.internet import address, threads +from twisted.internet import address, threads, udp +from twisted.internet._resolver import HostResolution +from twisted.internet.address import IPv4Address from twisted.internet.defer import Deferred +from twisted.internet.error import DNSLookupError +from twisted.internet.interfaces import IReactorPluggableNameResolver from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactorClock @@ -65,7 +70,7 @@ class FakeChannel(object): def getPeer(self): # We give an address so that getClientIP returns a non null entry, # causing us to record the MAU - return address.IPv4Address(b"TCP", "127.0.0.1", 3423) + return address.IPv4Address("TCP", "127.0.0.1", 3423) def getHost(self): return None @@ -93,7 +98,7 @@ class FakeSite: return FakeLogger() -def make_request(method, path, content=b"", access_token=None): +def make_request(method, path, content=b"", access_token=None, request=SynapseRequest): """ Make a web request using the given method and path, feed it the content, and return the Request and the Channel underneath. @@ -115,14 +120,18 @@ def make_request(method, path, content=b"", access_token=None): site = FakeSite() channel = FakeChannel() - req = SynapseRequest(site, channel) + req = request(site, channel) req.process = lambda: b"" req.content = BytesIO(content) if access_token: - req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token) + req.requestHeaders.addRawHeader( + b"Authorization", b"Bearer " + access_token.encode('ascii') + ) + + if content: + req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") - req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1") req.requestReceived(method, path, b"1.1") return req, channel @@ -154,11 +163,46 @@ def render(request, resource, clock): wait_until_result(clock, request) +@implementer(IReactorPluggableNameResolver) class ThreadedMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports callFromThread. """ + def __init__(self): + self._udp = [] + self.lookups = {} + + class Resolver(object): + def resolveHostName( + _self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics='TCP', + ): + + resolution = HostResolution(hostName) + resolutionReceiver.resolutionBegan(resolution) + if hostName not in self.lookups: + raise DNSLookupError("OH NO") + + resolutionReceiver.addressResolved( + IPv4Address('TCP', self.lookups[hostName], portNumber) + ) + resolutionReceiver.resolutionComplete() + return resolution + + self.nameResolver = Resolver() + super(ThreadedMemoryReactorClock, self).__init__() + + def listenUDP(self, port, protocol, interface='', maxPacketSize=8196): + p = udp.Port(port, protocol, interface, maxPacketSize, self) + p.startListening() + self._udp.append(p) + return p + def callFromThread(self, callback, *args, **kwargs): """ Make the callback fire in the next reactor iteration. @@ -240,3 +284,84 @@ def get_clock(): clock = ThreadedMemoryReactorClock() hs_clock = Clock(clock) return (clock, hs_clock) + + +@attr.s +class FakeTransport(object): + """ + A twisted.internet.interfaces.ITransport implementation which sends all its data + straight into an IProtocol object: it exists to connect two IProtocols together. + + To use it, instantiate it with the receiving IProtocol, and then pass it to the + sending IProtocol's makeConnection method: + + server = HTTPChannel() + client.makeConnection(FakeTransport(server, self.reactor)) + + If you want bidirectional communication, you'll need two instances. + """ + + other = attr.ib() + """The Protocol object which will receive any data written to this transport. + + :type: twisted.internet.interfaces.IProtocol + """ + + _reactor = attr.ib() + """Test reactor + + :type: twisted.internet.interfaces.IReactorTime + """ + + disconnecting = False + buffer = attr.ib(default=b'') + producer = attr.ib(default=None) + + def getPeer(self): + return None + + def getHost(self): + return None + + def loseConnection(self): + self.disconnecting = True + + def abortConnection(self): + self.disconnecting = True + + def pauseProducing(self): + self.producer.pauseProducing() + + def unregisterProducer(self): + if not self.producer: + return + + self.producer = None + + def registerProducer(self, producer, streaming): + self.producer = producer + self.producerStreaming = streaming + + def _produce(): + d = self.producer.resumeProducing() + d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) + + if not streaming: + self._reactor.callLater(0.0, _produce) + + def write(self, byt): + self.buffer = self.buffer + byt + + def _write(): + if getattr(self.other, "transport") is not None: + self.other.dataReceived(self.buffer) + self.buffer = b"" + return + + self._reactor.callLater(0.0, _write) + + _write() + + def writeSequence(self, seq): + for x in seq: + self.write(x) diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py new file mode 100644 index 0000000000..95badc985e --- /dev/null +++ b/tests/server_notices/test_consent.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v2_alpha import sync + +from tests import unittest + + +class ConsentNoticesTests(unittest.HomeserverTestCase): + + servlets = [ + sync.register_servlets, + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + self.consent_notice_message = "consent %(consent_uri)s" + config = self.default_config() + config.user_consent_version = "1" + config.user_consent_server_notice_content = { + "msgtype": "m.text", + "body": self.consent_notice_message, + } + config.public_baseurl = "https://example.com/" + config.form_secret = "123abc" + + config.server_notices_mxid = "@notices:test" + config.server_notices_mxid_display_name = "test display name" + config.server_notices_mxid_avatar_url = None + config.server_notices_room_name = "Server Notices" + + hs = self.setup_test_homeserver(config=config) + + return hs + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("bob", "abc123") + self.access_token = self.login("bob", "abc123") + + def test_get_sync_message(self): + """ + When user consent server notices are enabled, a sync will cause a notice + to fire (in a room which the user is invited to). The notice contains + the notice URL + an authentication code. + """ + # Initial sync, to get the user consent room invite + request, channel = self.make_request( + "GET", "/_matrix/client/r0/sync", access_token=self.access_token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # Get the Room ID to join + room_id = list(channel.json_body["rooms"]["invite"].keys())[0] + + # Join the room + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/" + room_id + "/join", + access_token=self.access_token, + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # Sync again, to get the message in the room + request, channel = self.make_request( + "GET", "/_matrix/client/r0/sync", access_token=self.access_token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # Get the message + room = channel.json_body["rooms"]["join"][room_id] + messages = [ + x for x in room["timeline"]["events"] if x["type"] == "m.room.message" + ] + + # One message, with the consent URL + self.assertEqual(len(messages), 1) + self.assertTrue( + messages[0]["content"]["body"].startswith( + "consent https://example.com/_matrix/consent" + ) + ) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 5cc7fff39b..b1551df7ca 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -4,7 +4,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError -from synapse.handlers.auth import AuthHandler from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) @@ -13,17 +12,10 @@ from tests import unittest from tests.utils import setup_test_homeserver -class AuthHandlers(object): - def __init__(self, hs): - self.auth_handler = AuthHandler(hs) - - class TestResourceLimitsServerNotices(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) - self.hs.handlers = AuthHandlers(self.hs) - self.auth_handler = self.hs.handlers.auth_handler + self.hs = yield setup_test_homeserver(self.addCleanup) self.server_notices_sender = self.hs.get_server_notices_sender() # relying on [1] is far from ideal, but the only case where @@ -80,12 +72,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase): self._rlsn._auth.check_auth_blocking = Mock() mock_event = Mock( - type=EventTypes.Message, - content={"msgtype": ServerNoticeMsgType}, + type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} + ) + self._rlsn._store.get_events = Mock( + return_value=defer.succeed({"123": mock_event}) ) - self._rlsn._store.get_events = Mock(return_value=defer.succeed( - {"123": mock_event} - )) yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) # Would be better to check the content, but once == remove blocking event @@ -99,12 +90,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase): ) mock_event = Mock( - type=EventTypes.Message, - content={"msgtype": ServerNoticeMsgType}, + type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} + ) + self._rlsn._store.get_events = Mock( + return_value=defer.succeed({"123": mock_event}) ) - self._rlsn._store.get_events = Mock(return_value=defer.succeed( - {"123": mock_event} - )) yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) self._send_notice.assert_not_called() @@ -177,13 +167,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): @defer.inlineCallbacks def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock( - return_value=1000, - ) + self.store.get_monthly_active_count = Mock(return_value=1000) - self.store.user_last_seen_monthly_active = Mock( - return_value=1000, - ) + self.store.user_last_seen_monthly_active = Mock(return_value=1000) # Call the function multiple times to ensure we only send the notice once yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) @@ -193,12 +179,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): # Now lets get the last load of messages in the service notice room and # check that there is only one server notice room_id = yield self.server_notices_manager.get_notice_room_for_user( - self.user_id, + self.user_id ) token = yield self.event_source.get_current_token() events, _ = yield self.store.get_recent_events_for_room( - room_id, limit=100, end_token=token.room_key, + room_id, limit=100, end_token=token.room_key ) count = 0 diff --git a/tests/state/__init__.py b/tests/state/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/state/__init__.py diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py new file mode 100644 index 0000000000..efd85ebe6c --- /dev/null +++ b/tests/state/test_v2.py @@ -0,0 +1,663 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +from six.moves import zip + +import attr + +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.event_auth import auth_types_for_event +from synapse.events import FrozenEvent +from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store +from synapse.types import EventID + +from tests import unittest + +ALICE = "@alice:example.com" +BOB = "@bob:example.com" +CHARLIE = "@charlie:example.com" +EVELYN = "@evelyn:example.com" +ZARA = "@zara:example.com" + +ROOM_ID = "!test:example.com" + +MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN} +MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN} + + +ORIGIN_SERVER_TS = 0 + + +class FakeEvent(object): + """A fake event we use as a convenience. + + NOTE: Again as a convenience we use "node_ids" rather than event_ids to + refer to events. The event_id has node_id as localpart and example.com + as domain. + """ + def __init__(self, id, sender, type, state_key, content): + self.node_id = id + self.event_id = EventID(id, "example.com").to_string() + self.sender = sender + self.type = type + self.state_key = state_key + self.content = content + + def to_event(self, auth_events, prev_events): + """Given the auth_events and prev_events, convert to a Frozen Event + + Args: + auth_events (list[str]): list of event_ids + prev_events (list[str]): list of event_ids + + Returns: + FrozenEvent + """ + global ORIGIN_SERVER_TS + + ts = ORIGIN_SERVER_TS + ORIGIN_SERVER_TS = ORIGIN_SERVER_TS + 1 + + event_dict = { + "auth_events": [(a, {}) for a in auth_events], + "prev_events": [(p, {}) for p in prev_events], + "event_id": self.node_id, + "sender": self.sender, + "type": self.type, + "content": self.content, + "origin_server_ts": ts, + "room_id": ROOM_ID, + } + + if self.state_key is not None: + event_dict["state_key"] = self.state_key + + return FrozenEvent(event_dict) + + +# All graphs start with this set of events +INITIAL_EVENTS = [ + FakeEvent( + id="CREATE", + sender=ALICE, + type=EventTypes.Create, + state_key="", + content={"creator": ALICE}, + ), + FakeEvent( + id="IMA", + sender=ALICE, + type=EventTypes.Member, + state_key=ALICE, + content=MEMBERSHIP_CONTENT_JOIN, + ), + FakeEvent( + id="IPOWER", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100}}, + ), + FakeEvent( + id="IJR", + sender=ALICE, + type=EventTypes.JoinRules, + state_key="", + content={"join_rule": JoinRules.PUBLIC}, + ), + FakeEvent( + id="IMB", + sender=BOB, + type=EventTypes.Member, + state_key=BOB, + content=MEMBERSHIP_CONTENT_JOIN, + ), + FakeEvent( + id="IMC", + sender=CHARLIE, + type=EventTypes.Member, + state_key=CHARLIE, + content=MEMBERSHIP_CONTENT_JOIN, + ), + FakeEvent( + id="IMZ", + sender=ZARA, + type=EventTypes.Member, + state_key=ZARA, + content=MEMBERSHIP_CONTENT_JOIN, + ), + FakeEvent( + id="START", + sender=ZARA, + type=EventTypes.Message, + state_key=None, + content={}, + ), + FakeEvent( + id="END", + sender=ZARA, + type=EventTypes.Message, + state_key=None, + content={}, + ), +] + +INITIAL_EDGES = [ + "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE", +] + + +class StateTestCase(unittest.TestCase): + def test_ban_vs_pl(self): + events = [ + FakeEvent( + id="PA", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={ + "users": { + ALICE: 100, + BOB: 50, + } + }, + ), + FakeEvent( + id="MA", + sender=ALICE, + type=EventTypes.Member, + state_key=ALICE, + content={"membership": Membership.JOIN}, + ), + FakeEvent( + id="MB", + sender=ALICE, + type=EventTypes.Member, + state_key=BOB, + content={"membership": Membership.BAN}, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + ] + + edges = [ + ["END", "MB", "MA", "PA", "START"], + ["END", "PB", "PA"], + ] + + expected_state_ids = ["PA", "MA", "MB"] + + self.do_check(events, edges, expected_state_ids) + + def test_join_rule_evasion(self): + events = [ + FakeEvent( + id="JR", + sender=ALICE, + type=EventTypes.JoinRules, + state_key="", + content={"join_rules": JoinRules.PRIVATE}, + ), + FakeEvent( + id="ME", + sender=EVELYN, + type=EventTypes.Member, + state_key=EVELYN, + content={"membership": Membership.JOIN}, + ), + ] + + edges = [ + ["END", "JR", "START"], + ["END", "ME", "START"], + ] + + expected_state_ids = ["JR"] + + self.do_check(events, edges, expected_state_ids) + + def test_offtopic_pl(self): + events = [ + FakeEvent( + id="PA", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={ + "users": { + ALICE: 100, + BOB: 50, + } + }, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + CHARLIE: 50, + }, + }, + ), + FakeEvent( + id="PC", + sender=CHARLIE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + CHARLIE: 0, + }, + }, + ), + ] + + edges = [ + ["END", "PC", "PB", "PA", "START"], + ["END", "PA"], + ] + + expected_state_ids = ["PC"] + + self.do_check(events, edges, expected_state_ids) + + def test_topic_basic(self): + events = [ + FakeEvent( + id="T1", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="PA1", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + FakeEvent( + id="T2", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="PA2", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 0, + }, + }, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + FakeEvent( + id="T3", + sender=BOB, + type=EventTypes.Topic, + state_key="", + content={}, + ), + ] + + edges = [ + ["END", "PA2", "T2", "PA1", "T1", "START"], + ["END", "T3", "PB", "PA1"], + ] + + expected_state_ids = ["PA2", "T2"] + + self.do_check(events, edges, expected_state_ids) + + def test_topic_reset(self): + events = [ + FakeEvent( + id="T1", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="PA", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + FakeEvent( + id="T2", + sender=BOB, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="MB", + sender=ALICE, + type=EventTypes.Member, + state_key=BOB, + content={"membership": Membership.BAN}, + ), + ] + + edges = [ + ["END", "MB", "T2", "PA", "T1", "START"], + ["END", "T1"], + ] + + expected_state_ids = ["T1", "MB", "PA"] + + self.do_check(events, edges, expected_state_ids) + + def test_topic(self): + events = [ + FakeEvent( + id="T1", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="PA1", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + FakeEvent( + id="T2", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="PA2", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 0, + }, + }, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + FakeEvent( + id="T3", + sender=BOB, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="MZ1", + sender=ZARA, + type=EventTypes.Message, + state_key=None, + content={}, + ), + FakeEvent( + id="T4", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + ] + + edges = [ + ["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"], + ["END", "MZ1", "T3", "PB", "PA1"], + ] + + expected_state_ids = ["T4", "PA2"] + + self.do_check(events, edges, expected_state_ids) + + def do_check(self, events, edges, expected_state_ids): + """Take a list of events and edges and calculate the state of the + graph at END, and asserts it matches `expected_state_ids` + + Args: + events (list[FakeEvent]) + edges (list[list[str]]): A list of chains of event edges, e.g. + `[[A, B, C]]` are edges A->B and B->C. + expected_state_ids (list[str]): The expected state at END, (excluding + the keys that haven't changed since START). + """ + # We want to sort the events into topological order for processing. + graph = {} + + # node_id -> FakeEvent + fake_event_map = {} + + for ev in itertools.chain(INITIAL_EVENTS, events): + graph[ev.node_id] = set() + fake_event_map[ev.node_id] = ev + + for a, b in pairwise(INITIAL_EDGES): + graph[a].add(b) + + for edge_list in edges: + for a, b in pairwise(edge_list): + graph[a].add(b) + + # event_id -> FrozenEvent + event_map = {} + # node_id -> state + state_at_event = {} + + # We copy the map as the sort consumes the graph + graph_copy = {k: set(v) for k, v in graph.items()} + + for node_id in lexicographical_topological_sort(graph_copy, key=lambda e: e): + fake_event = fake_event_map[node_id] + event_id = fake_event.event_id + + prev_events = list(graph[node_id]) + + if len(prev_events) == 0: + state_before = {} + elif len(prev_events) == 1: + state_before = dict(state_at_event[prev_events[0]]) + else: + state_d = resolve_events_with_store( + [state_at_event[n] for n in prev_events], + event_map=event_map, + state_res_store=TestStateResolutionStore(event_map), + ) + + self.assertTrue(state_d.called) + state_before = state_d.result + + state_after = dict(state_before) + if fake_event.state_key is not None: + state_after[(fake_event.type, fake_event.state_key)] = event_id + + auth_types = set(auth_types_for_event(fake_event)) + + auth_events = [] + for key in auth_types: + if key in state_before: + auth_events.append(state_before[key]) + + event = fake_event.to_event(auth_events, prev_events) + + state_at_event[node_id] = state_after + event_map[event_id] = event + + expected_state = {} + for node_id in expected_state_ids: + # expected_state_ids are node IDs rather than event IDs, + # so we have to convert + event_id = EventID(node_id, "example.com").to_string() + event = event_map[event_id] + + key = (event.type, event.state_key) + + expected_state[key] = event_id + + start_state = state_at_event["START"] + end_state = { + key: value + for key, value in state_at_event["END"].items() + if key in expected_state or start_state.get(key) != value + } + + self.assertEqual(expected_state, end_state) + + +class LexicographicalTestCase(unittest.TestCase): + def test_simple(self): + graph = { + "l": {"o"}, + "m": {"n", "o"}, + "n": {"o"}, + "o": set(), + "p": {"o"}, + } + + res = list(lexicographical_topological_sort(graph, key=lambda x: x)) + + self.assertEqual(["o", "l", "n", "m", "p"], res) + + +def pairwise(iterable): + "s -> (s0,s1), (s1,s2), (s2, s3), ..." + a, b = itertools.tee(iterable) + next(b, None) + return zip(a, b) + + +@attr.s +class TestStateResolutionStore(object): + event_map = attr.ib() + + def get_events(self, event_ids, allow_rejected=False): + """Get events from the database + + Args: + event_ids (list): The event_ids of the events to fetch + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + """ + + return { + eid: self.event_map[eid] + for eid in event_ids + if eid in self.event_map + } + + def get_auth_chain(self, event_ids): + """Gets the full auth chain for a set of events (including rejected + events). + + Includes the given event IDs in the result. + + Note that: + 1. All events must be state events. + 2. For v1 rooms this may not have the full auth chain in the + presence of rejected events + + Args: + event_ids (list): The event IDs of the events to fetch the auth + chain for. Must be state events. + + Returns: + Deferred[list[str]]: List of event IDs of the auth chain. + """ + + # Simple DFS for auth chain + result = set() + stack = list(event_ids) + while stack: + event_id = stack.pop() + if event_id in result: + continue + + result.add(event_id) + + event = self.event_map[event_id] + for aid, _ in event.auth_events: + stack.append(aid) + + return list(result) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c2e88bdbaf..4577e9422b 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,35 +13,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from mock import Mock from twisted.internet import defer -import tests.unittest -import tests.utils +from synapse.http.site import XForwardedForRequest +from synapse.rest.client.v1 import admin, login + +from tests import unittest -class ClientIpStoreTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super(ClientIpStoreTestCase, self).__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore - self.clock = None # type: tests.utils.MockClock +class ClientIpStoreTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs - @defer.inlineCallbacks - def setUp(self): - self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup) + def prepare(self, hs, reactor, clock): self.store = self.hs.get_datastore() - self.clock = self.hs.get_clock() - @defer.inlineCallbacks def test_insert_new_client_ip(self): - self.clock.now = 12345678 + self.reactor.advance(12345678) + user_id = "@user:id" - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - result = yield self.store.get_last_client_ip_by_device(user_id, "device_id") + # Trigger the storage loop + self.reactor.advance(10) + + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, "device_id") + ) r = result[(user_id, "device_id")] self.assertDictContainsSubset( @@ -55,18 +62,18 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): r, ) - @defer.inlineCallbacks def test_disabled_monthly_active_user(self): self.hs.config.limit_usage_by_mau = False self.hs.config.max_mau_value = 50 user_id = "@user:server" - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - @defer.inlineCallbacks def test_adding_monthly_active_user_when_full(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 @@ -76,40 +83,119 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(lots_of_users) ) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - @defer.inlineCallbacks def test_adding_monthly_active_user_when_space(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + # Trigger the saving loop + self.reactor.advance(10) + + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) - @defer.inlineCallbacks def test_updating_monthly_active_user_when_space(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" + self.get_success( + self.store.register(user_id=user_id, token="123", password_hash=None) + ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" - ) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + # Trigger the saving loop + self.reactor.advance(10) + + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + + +class ClientIpAuthTestCase(unittest.HomeserverTestCase): + + servlets = [admin.register_servlets, login.register_servlets] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs + + def prepare(self, hs, reactor, clock): + self.store = self.hs.get_datastore() + self.user_id = self.register_user("bob", "abc123", True) + + def test_request_with_xforwarded(self): + """ + The IP in X-Forwarded-For is entered into the client IPs table. + """ + self._runtest( + {b"X-Forwarded-For": b"127.9.0.1"}, + "127.9.0.1", + {"request": XForwardedForRequest}, + ) + + def test_request_from_getPeer(self): + """ + The IP returned by getPeer is entered into the client IPs table, if + there's no X-Forwarded-For header. + """ + self._runtest({}, "127.0.0.1", {}) + + def _runtest(self, headers, expected_ip, make_request_args): + device_id = "bleb" + + access_token = self.login("bob", "abc123", device_id=device_id) + + # Advance to a known time + self.reactor.advance(123456 - self.reactor.seconds()) + + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/admin/users/" + self.user_id, + access_token=access_token, + **make_request_args + ) + request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") + + # Add the optional headers + for h, v in headers.items(): + request.requestHeaders.addRawHeader(h, v) + self.render(request) + + # Advance so the save loop occurs + self.reactor.advance(100) + + result = self.get_success( + self.store.get_last_client_ip_by_device(self.user_id, device_id) + ) + r = result[(self.user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": self.user_id, + "device_id": device_id, + "ip": expected_ip, + "user_agent": "Mozzila pizza", + "last_seen": 123456100, + }, + r, + ) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 2036287288..832e379a83 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -12,6 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + +from twisted.internet import defer from tests.unittest import HomeserverTestCase @@ -23,7 +26,8 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): hs = self.setup_test_homeserver() self.store = hs.get_datastore() - + hs.config.limit_usage_by_mau = True + hs.config.max_mau_value = 50 # Advance the clock a bit reactor.advance(FORTY_DAYS) @@ -48,7 +52,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - self.store.initialise_reserved_users(threepids) + + self.store.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) self.pump() active_count = self.store.get_monthly_active_count() @@ -73,7 +80,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): active_count = self.store.get_monthly_active_count() self.assertEquals(self.get_success(active_count), user_num) - # Test that regalar users are removed from the db + # Test that regular users are removed from the db ru_count = 2 self.store.upsert_monthly_active_user("@ru1:server") self.store.upsert_monthly_active_user("@ru2:server") @@ -139,3 +146,77 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): count = self.store.get_monthly_active_count() self.assertEquals(self.get_success(count), 0) + + def test_populate_monthly_users_is_guest(self): + # Test that guest users are not added to mau list + user_id = "user_id" + self.store.register( + user_id=user_id, token="123", password_hash=None, make_guest=True + ) + self.store.upsert_monthly_active_user = Mock() + self.store.populate_monthly_active_users(user_id) + self.pump() + self.store.upsert_monthly_active_user.assert_not_called() + + def test_populate_monthly_users_should_update(self): + self.store.upsert_monthly_active_user = Mock() + + self.store.is_trial_user = Mock( + return_value=defer.succeed(False) + ) + + self.store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(None) + ) + self.store.populate_monthly_active_users('user_id') + self.pump() + self.store.upsert_monthly_active_user.assert_called_once() + + def test_populate_monthly_users_should_not_update(self): + self.store.upsert_monthly_active_user = Mock() + + self.store.is_trial_user = Mock( + return_value=defer.succeed(False) + ) + self.store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed( + self.hs.get_clock().time_msec() + ) + ) + self.store.populate_monthly_active_users('user_id') + self.pump() + self.store.upsert_monthly_active_user.assert_not_called() + + def test_get_reserved_real_user_account(self): + # Test no reserved users, or reserved threepids + count = self.store.get_registered_reserved_users_count() + self.assertEquals(self.get_success(count), 0) + # Test reserved users but no registered users + + user1 = '@user1:example.com' + user2 = '@user2:example.com' + user1_email = 'user1@example.com' + user2_email = 'user2@example.com' + threepids = [ + {'medium': 'email', 'address': user1_email}, + {'medium': 'email', 'address': user2_email}, + ] + self.hs.config.mau_limits_reserved_threepids = threepids + self.store.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) + + self.pump() + count = self.store.get_registered_reserved_users_count() + self.assertEquals(self.get_success(count), 0) + + # Test reserved registed users + self.store.register(user_id=user1, token="123", password_hash=None) + self.store.register(user_id=user2, token="456", password_hash=None) + self.pump() + + now = int(self.hs.get_clock().time_msec()) + self.store.user_add_threepid(user1, "email", user1_email, now, now) + self.store.user_add_threepid(user2, "email", user2_email, now, now) + count = self.store.get_registered_reserved_users_count() + self.assertEquals(self.get_success(count), len(threepids)) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index d717b9f94e..086a39d834 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -18,6 +18,7 @@ import logging from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.storage.state import StateFilter from synapse.types import RoomID, UserID import tests.unittest @@ -75,6 +76,45 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(len(s1), len(s2)) @defer.inlineCallbacks + def test_get_state_groups_ids(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_map = list(state_group_map.values())[0] + self.assertDictEqual( + state_map, + { + (EventTypes.Create, ''): e1.event_id, + (EventTypes.Name, ''): e2.event_id, + }, + ) + + @defer.inlineCallbacks + def test_get_state_groups(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups( + self.room, [e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_list = list(state_group_map.values())[0] + + self.assertEqual( + {ev.event_id for ev in state_list}, + {e1.event_id, e2.event_id}, + ) + + @defer.inlineCallbacks def test_get_state_for_event(self): # this defaults to a linear DAG as each new injection defaults to whatever @@ -109,7 +149,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we get the full state as of the final event state = yield self.store.get_state_for_event( - e5.event_id, None, filtered_types=None + e5.event_id, ) self.assertIsNotNone(e4) @@ -127,33 +167,35 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can filter to the m.room.name event (with a '' state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Name, '')], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Name, '')]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Name, None)], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Member, None)], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) self.assertStateMapEqual( {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state ) - # check we can use filtered_types to grab a specific room member - # without filtering out the other event types + # check we can grab a specific room member without filtering out the + # other event types state = yield self.store.get_state_for_event( e5.event_id, - [(EventTypes.Member, self.u_alice.to_string())], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ) ) self.assertStateMapEqual( @@ -165,10 +207,12 @@ class StateStoreTestCase(tests.unittest.TestCase): state, ) - # check that types=[], filtered_types=[EventTypes.Member] - # doesn't return all members + # check that we can grab everything except members state = yield self.store.get_state_for_event( - e5.event_id, [], filtered_types=[EventTypes.Member] + e5.event_id, state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertStateMapEqual( @@ -176,17 +220,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) ####################################################### - # _get_some_state_from_cache tests against a full cache + # _get_state_for_group_using_cache tests against a full cache ####################################################### room_id = self.room.to_string() group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) group = list(group_ids.keys())[0] - # test _get_some_state_from_cache correctly filters out members with types=[] - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - self.store._state_group_cache, - group, [], filtered_types=[EventTypes.Member] + # test _get_state_for_group_using_cache correctly filters out members + # with types=[] + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -198,21 +246,27 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, - group, [], filtered_types=[EventTypes.Member] + group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) - self.assertDictEqual( - {}, - state_dict, - ) + self.assertDictEqual({}, state_dict) - # test _get_some_state_from_cache correctly filters in members with wildcard types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with wildcard types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, - group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + group, + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -224,9 +278,13 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, - group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + group, + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -239,12 +297,15 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - # test _get_some_state_from_cache correctly filters in members with specific types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -256,26 +317,27 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) - self.assertDictEqual( - { - (e5.type, e5.state_key): e5.event_id, - }, - state_dict, - ) + self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - # test _get_some_state_from_cache correctly filters in members with specific types - # and no filtered_types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, - group, [(EventTypes.Member, e5.state_key)], filtered_types=None + group, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, True) @@ -305,9 +367,7 @@ class StateStoreTestCase(tests.unittest.TestCase): key=group, value=state_dict_ids, # list fetched keys so it knows it's partial - fetched_keys=( - (e1.type, e1.state_key), - ), + fetched_keys=((e1.type, e1.state_key),), ) (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( @@ -315,60 +375,60 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, False) - self.assertEqual( - known_absent, - set( - [ - (e1.type, e1.state_key), - ] - ), - ) - self.assertDictEqual( - state_dict_ids, - { - (e1.type, e1.state_key): e1.event_id, - }, - ) + self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) + self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) ############################################ # test that things work with a partial cache - # test _get_some_state_from_cache correctly filters out members with types=[] + # test _get_state_for_group_using_cache correctly filters out members + # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - self.store._state_group_cache, - group, [], filtered_types=[EventTypes.Member] + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, - group, [], filtered_types=[EventTypes.Member] + group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) self.assertDictEqual({}, state_dict) - # test _get_some_state_from_cache correctly filters in members wildcard types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # wildcard types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, - group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + group, + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, False) - self.assertDictEqual( - { - (e1.type, e1.state_key): e1.event_id, - }, - state_dict, - ) + self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, - group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + group, + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -380,56 +440,54 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - # test _get_some_state_from_cache correctly filters in members with specific types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, False) - self.assertDictEqual( - { - (e1.type, e1.state_key): e1.event_id, - }, - state_dict, - ) + self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) - self.assertDictEqual( - { - (e5.type, e5.state_key): e5.event_id, - }, - state_dict, - ) + self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - # test _get_some_state_from_cache correctly filters in members with specific types - # and no filtered_types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_cache, - group, [(EventTypes.Member, e5.state_key)], filtered_types=None + group, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( self.store._state_group_members_cache, - group, [(EventTypes.Member, e5.state_key)], filtered_types=None + group, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, True) - self.assertDictEqual( - { - (e5.type, e5.state_key): e5.event_id, - }, - state_dict, - ) + self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py new file mode 100644 index 0000000000..14169afa96 --- /dev/null +++ b/tests/storage/test_transactions.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.unittest import HomeserverTestCase + + +class TransactionStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() + + def test_get_set_transactions(self): + """Tests that we can successfully get a non-existent entry for + destination retries, as well as testing tht we can set and get + correctly. + """ + d = self.store.get_destination_retry_timings("example.com") + r = self.get_success(d) + self.assertIsNone(r) + + d = self.store.set_destination_retry_timings("example.com", 50, 100) + self.get_success(d) + + d = self.store.get_destination_retry_timings("example.com") + r = self.get_success(d) + + self.assert_dict({"retry_last_ts": 50, "retry_interval": 100}, r) + + def test_initial_set_transactions(self): + """Tests that we can successfully set the destination retries (there + was a bug around invalidating the cache that broke this) + """ + d = self.store.set_destination_retry_timings("example.com", 50, 100) + self.get_success(d) diff --git a/tests/test_federation.py b/tests/test_federation.py index 2540604fcc..952a0a7b51 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -6,6 +6,7 @@ from twisted.internet.defer import maybeDeferred, succeed from synapse.events import FrozenEvent from synapse.types import Requester, UserID from synapse.util import Clock +from synapse.util.logcontext import LoggingContext from tests import unittest from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver @@ -117,9 +118,10 @@ class MessageAcceptTests(unittest.TestCase): } ) - d = self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True - ) + with LoggingContext(request="lying_event"): + d = self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) # Step the reactor, so the database fetches come back self.reactor.advance(1) @@ -139,107 +141,3 @@ class MessageAcceptTests(unittest.TestCase): self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id ) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") - - def test_cant_hide_past_history(self): - """ - If you send a message, you must be able to provide the direct - prev_events that said event references. - """ - - def post_json(destination, path, data, headers=None, timeout=0): - if path.startswith("/_matrix/federation/v1/get_missing_events/"): - return { - "events": [ - { - "room_id": self.room_id, - "sender": "@baduser:test.serv", - "event_id": "three:test.serv", - "depth": 1000, - "origin_server_ts": 1, - "type": "m.room.message", - "origin": "test.serv", - "content": "hewwo?", - "auth_events": [], - "prev_events": [("four:test.serv", {})], - } - ] - } - - self.http_client.post_json = post_json - - def get_json(destination, path, args, headers=None): - if path.startswith("/_matrix/federation/v1/state_ids/"): - d = self.successResultOf( - self.homeserver.datastore.get_state_ids_for_event("one:test.serv") - ) - - return succeed( - { - "pdu_ids": [ - y - for x, y in d.items() - if x == ("m.room.member", "@us:test") - ], - "auth_chain_ids": list(d.values()), - } - ) - - self.http_client.get_json = get_json - - # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id - ) - )[0] - - # Make a good event - good_event = FrozenEvent( - { - "room_id": self.room_id, - "sender": "@baduser:test.serv", - "event_id": "one:test.serv", - "depth": 1000, - "origin_server_ts": 1, - "type": "m.room.message", - "origin": "test.serv", - "content": "hewwo?", - "auth_events": [], - "prev_events": [(most_recent, {})], - } - ) - - d = self.handler.on_receive_pdu( - "test.serv", good_event, sent_to_us_directly=True - ) - self.reactor.advance(1) - self.assertEqual(self.successResultOf(d), None) - - bad_event = FrozenEvent( - { - "room_id": self.room_id, - "sender": "@baduser:test.serv", - "event_id": "two:test.serv", - "depth": 1000, - "origin_server_ts": 1, - "type": "m.room.message", - "origin": "test.serv", - "content": "hewwo?", - "auth_events": [], - "prev_events": [("one:test.serv", {}), ("three:test.serv", {})], - } - ) - - d = self.handler.on_receive_pdu( - "test.serv", bad_event, sent_to_us_directly=True - ) - self.reactor.advance(1) - - extrem = maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id - ) - self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv") - - state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id) - self.reactor.advance(1) - self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys()) diff --git a/tests/test_mau.py b/tests/test_mau.py index 0732615447..5d387851c5 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -185,20 +185,20 @@ class TestMauLimit(unittest.TestCase): self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def create_user(self, localpart): - request_data = json.dumps({ - "username": localpart, - "password": "monkey", - "auth": {"type": LoginType.DUMMY}, - }) + request_data = json.dumps( + { + "username": localpart, + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + } + ) - request, channel = make_request(b"POST", b"/register", request_data) + request, channel = make_request("POST", "/register", request_data) render(request, self.resource, self.reactor) - if channel.result["code"] != b"200": + if channel.code != 200: raise HttpResponseException( - int(channel.result["code"]), - channel.result["reason"], - channel.result["body"], + channel.code, channel.result["reason"], channel.result["body"] ).to_synapse_error() access_token = channel.json_body["access_token"] @@ -206,12 +206,12 @@ class TestMauLimit(unittest.TestCase): return access_token def do_sync_for_user(self, token): - request, channel = make_request(b"GET", b"/sync", access_token=token) + request, channel = make_request( + "GET", "/sync", access_token=token + ) render(request, self.resource, self.reactor) - if channel.result["code"] != b"200": + if channel.code != 200: raise HttpResponseException( - int(channel.result["code"]), - channel.result["reason"], - channel.result["body"], + channel.code, channel.result["reason"], channel.result["body"] ).to_synapse_error() diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000000..17897711a1 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.metrics import InFlightGauge + +from tests import unittest + + +class TestMauLimit(unittest.TestCase): + def test_basic(self): + gauge = InFlightGauge( + "test1", "", + labels=["test_label"], + sub_metrics=["foo", "bar"], + ) + + def handle1(metrics): + metrics.foo += 2 + metrics.bar = max(metrics.bar, 5) + + def handle2(metrics): + metrics.foo += 3 + metrics.bar = max(metrics.bar, 7) + + gauge.register(("key1",), handle1) + + self.assert_dict({ + "test1_total": {("key1",): 1}, + "test1_foo": {("key1",): 2}, + "test1_bar": {("key1",): 5}, + }, self.get_metrics_from_gauge(gauge)) + + gauge.unregister(("key1",), handle1) + + self.assert_dict({ + "test1_total": {("key1",): 0}, + "test1_foo": {("key1",): 0}, + "test1_bar": {("key1",): 0}, + }, self.get_metrics_from_gauge(gauge)) + + gauge.register(("key1",), handle1) + gauge.register(("key2",), handle2) + + self.assert_dict({ + "test1_total": {("key1",): 1, ("key2",): 1}, + "test1_foo": {("key1",): 2, ("key2",): 3}, + "test1_bar": {("key1",): 5, ("key2",): 7}, + }, self.get_metrics_from_gauge(gauge)) + + gauge.unregister(("key2",), handle2) + gauge.register(("key1",), handle2) + + self.assert_dict({ + "test1_total": {("key1",): 2, ("key2",): 0}, + "test1_foo": {("key1",): 5, ("key2",): 0}, + "test1_bar": {("key1",): 7, ("key2",): 0}, + }, self.get_metrics_from_gauge(gauge)) + + def get_metrics_from_gauge(self, gauge): + results = {} + + for r in gauge.collect(): + results[r.name] = { + tuple(labels[x] for x in gauge.labels): value + for _, labels, value in r.samples + } + + return results diff --git a/tests/test_server.py b/tests/test_server.py index ef74544e93..4045fdadc3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,14 +1,35 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging import re +from six import StringIO + from twisted.internet.defer import Deferred -from twisted.test.proto_helpers import MemoryReactorClock +from twisted.python.failure import Failure +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError from synapse.http.server import JsonResource +from synapse.http.site import SynapseSite, logger from synapse.util import Clock from tests import unittest -from tests.server import make_request, render, setup_test_homeserver +from tests.server import FakeTransport, make_request, render, setup_test_homeserver class JsonResourceTests(unittest.TestCase): @@ -121,3 +142,52 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.result["code"], b'400') self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + +class SiteTestCase(unittest.HomeserverTestCase): + def test_lose_connection(self): + """ + We log the URI correctly redacted when we lose the connection. + """ + + class HangingResource(Resource): + """ + A Resource that strategically hangs, as if it were processing an + answer. + """ + + def render(self, request): + return NOT_DONE_YET + + # Set up a logging handler that we can inspect afterwards + output = StringIO() + handler = logging.StreamHandler(output) + logger.addHandler(handler) + old_level = logger.level + logger.setLevel(10) + self.addCleanup(logger.setLevel, old_level) + self.addCleanup(logger.removeHandler, handler) + + # Make a resource and a Site, the resource will hang and allow us to + # time out the request while it's 'processing' + base_resource = Resource() + base_resource.putChild(b'', HangingResource()) + site = SynapseSite("test", "site_tag", {}, base_resource, "1.0") + + server = site.buildProtocol(None) + client = AccumulatingProtocol() + client.makeConnection(FakeTransport(server, self.reactor)) + server.makeConnection(FakeTransport(client, self.reactor)) + + # Send a request with an access token that will get redacted + server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n") + self.pump() + + # Lose the connection + e = Failure(Exception("Failed123")) + server.connectionLost(e) + handler.flush() + + # Our access token is redacted and the failure reason is logged. + self.assertIn("/?access_token=<redacted>", output.getvalue()) + self.assertIn("Failed123", output.getvalue()) diff --git a/tests/test_state.py b/tests/test_state.py index 452a123c3a..e20c33322a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -180,7 +180,7 @@ class StateTestCase(unittest.TestCase): graph = Graph( nodes={ "START": DictObj( - type=EventTypes.Create, state_key="", content={}, depth=1, + type=EventTypes.Create, state_key="", content={}, depth=1 ), "A": DictObj(type=EventTypes.Message, depth=2), "B": DictObj(type=EventTypes.Message, depth=3), diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py new file mode 100644 index 0000000000..7deab5266f --- /dev/null +++ b/tests/test_terms_auth.py @@ -0,0 +1,123 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import six +from mock import Mock + +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.rest.client.v2_alpha.register import register_servlets +from synapse.util import Clock + +from tests import unittest +from tests.server import make_request + + +class TermsTestCase(unittest.HomeserverTestCase): + servlets = [register_servlets] + + def prepare(self, reactor, clock, hs): + self.clock = MemoryReactorClock() + self.hs_clock = Clock(self.clock) + self.url = "/_matrix/client/r0/register" + self.registration_handler = Mock() + self.auth_handler = Mock() + self.device_handler = Mock() + hs.config.enable_registration = True + hs.config.registrations_require_3pid = [] + hs.config.auto_join_rooms = [] + hs.config.enable_registration_captcha = False + + def test_ui_auth(self): + self.hs.config.block_events_without_consent_error = True + self.hs.config.public_baseurl = "https://example.org" + self.hs.config.user_consent_version = "1.0" + + # Do a UI auth request + request, channel = self.make_request(b"POST", self.url, b"{}") + self.render(request) + + self.assertEquals(channel.result["code"], b"401", channel.result) + + self.assertTrue(channel.json_body is not None) + self.assertIsInstance(channel.json_body["session"], six.text_type) + + self.assertIsInstance(channel.json_body["flows"], list) + for flow in channel.json_body["flows"]: + self.assertIsInstance(flow["stages"], list) + self.assertTrue(len(flow["stages"]) > 0) + self.assertEquals(flow["stages"][-1], "m.login.terms") + + expected_params = { + "m.login.terms": { + "policies": { + "privacy_policy": { + "en": { + "name": "Privacy Policy", + "url": "https://example.org/_matrix/consent?v=1.0", + }, + "version": "1.0" + }, + }, + }, + } + self.assertIsInstance(channel.json_body["params"], dict) + self.assertDictContainsSubset(channel.json_body["params"], expected_params) + + # We have to complete the dummy auth stage before completing the terms stage + request_data = json.dumps( + { + "username": "kermit", + "password": "monkey", + "auth": { + "session": channel.json_body["session"], + "type": "m.login.dummy", + }, + } + ) + + self.registration_handler.check_username = Mock(return_value=True) + + request, channel = make_request(b"POST", self.url, request_data) + self.render(request) + + # We don't bother checking that the response is correct - we'll leave that to + # other tests. We just want to make sure we're on the right path. + self.assertEquals(channel.result["code"], b"401", channel.result) + + # Finish the UI auth for terms + request_data = json.dumps( + { + "username": "kermit", + "password": "monkey", + "auth": { + "session": channel.json_body["session"], + "type": "m.login.terms", + }, + } + ) + request, channel = make_request(b"POST", self.url, request_data) + self.render(request) + + # We're interested in getting a response that looks like a successful + # registration, not so much that the details are exactly what we want. + + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertTrue(channel.json_body is not None) + self.assertIsInstance(channel.json_body["user_id"], six.text_type) + self.assertIsInstance(channel.json_body["access_token"], six.text_type) + self.assertIsInstance(channel.json_body["device_id"], six.text_type) diff --git a/tests/unittest.py b/tests/unittest.py index a3d39920db..4d40bdb6a5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import hmac import logging from mock import Mock @@ -26,11 +28,13 @@ from twisted.internet.defer import Deferred from twisted.trial import unittest from synapse.http.server import JsonResource +from synapse.http.site import SynapseRequest from synapse.server import HomeServer from synapse.types import UserID, create_requester from synapse.util.logcontext import LoggingContextFilter from tests.server import get_clock, make_request, render, setup_test_homeserver +from tests.utils import default_config # Set up putting Synapse's logs into Trial's. rootLogger = logging.getLogger() @@ -142,6 +146,13 @@ def DEBUG(target): return target +def INFO(target): + """A decorator to set the .loglevel attribute to logging.INFO. + Can apply to either a TestCase or an individual test method.""" + target.loglevel = logging.INFO + return target + + class HomeserverTestCase(TestCase): """ A base TestCase that reduces boilerplate for HomeServer-using test cases. @@ -219,7 +230,17 @@ class HomeserverTestCase(TestCase): Function to be overridden in subclasses. """ - raise NotImplementedError() + hs = self.setup_test_homeserver() + return hs + + def default_config(self, name="test"): + """ + Get a default HomeServer config object. + + Args: + name (str): The homeserver name/domain. + """ + return default_config(name) def prepare(self, reactor, clock, homeserver): """ @@ -236,7 +257,9 @@ class HomeserverTestCase(TestCase): Function to optionally be overridden in subclasses. """ - def make_request(self, method, path, content=b""): + def make_request( + self, method, path, content=b"", access_token=None, request=SynapseRequest + ): """ Create a SynapseRequest at the path using the method and containing the given content. @@ -254,7 +277,7 @@ class HomeserverTestCase(TestCase): if isinstance(content, dict): content = json.dumps(content).encode('utf8') - return make_request(method, path, content) + return make_request(method, path, content, access_token, request) def render(self, request): """ @@ -293,3 +316,69 @@ class HomeserverTestCase(TestCase): return d self.pump() return self.successResultOf(d) + + def register_user(self, username, password, admin=False): + """ + Register a user. Requires the Admin API be registered. + + Args: + username (bytes/unicode): The user part of the new user. + password (bytes/unicode): The password of the new user. + admin (bool): Whether the user should be created as an admin + or not. + + Returns: + The MXID of the new user (unicode). + """ + self.hs.config.registration_shared_secret = u"shared" + + # Create the user + request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')]) + if admin: + nonce_str += b"\x00admin" + else: + nonce_str += b"\x00notadmin" + want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str) + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": username, + "password": password, + "admin": admin, + "mac": want_mac, + } + ) + request, channel = self.make_request( + "POST", "/_matrix/client/r0/admin/register", body.encode('utf8') + ) + self.render(request) + self.assertEqual(channel.code, 200) + + user_id = channel.json_body["user_id"] + return user_id + + def login(self, username, password, device_id=None): + """ + Log in a user, and get an access token. Requires the Login API be + registered. + + """ + body = {"type": "m.login.password", "user": username, "password": password} + if device_id: + body["device_id"] = device_id + + request, channel = self.make_request( + "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8') + ) + self.render(request) + self.assertEqual(channel.code, 200) + + access_token = channel.json_body["access_token"] + return access_token diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 5cbada4eda..50bc7702d2 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -65,7 +65,6 @@ class ExpiringCacheTestCase(unittest.TestCase): def test_time_eviction(self): clock = MockClock() cache = ExpiringCache("test", clock, expiry_ms=1000) - cache.start() cache["key"] = 1 clock.advance_time(0.5) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 4633db77b3..8adaee3c8d 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -159,6 +159,11 @@ class LoggingContextTestCase(unittest.TestCase): self.assertEqual(r, "bum") self._check_test_key("one") + def test_nested_logging_context(self): + with LoggingContext(request="foo"): + nested_context = logcontext.nested_logging_context(suffix="bar") + self.assertEqual(nested_context.request, "foo-bar") + # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on diff --git a/tests/utils.py b/tests/utils.py index b85017d279..565bb60d08 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,9 @@ import atexit import hashlib import os +import time import uuid +import warnings from inspect import getcallargs from mock import Mock, patch @@ -94,14 +96,78 @@ def setupdb(): atexit.register(_cleanup) +def default_config(name): + """ + Create a reasonable test config. + """ + config = Mock() + config.signing_key = [MockKey()] + config.event_cache_size = 1 + config.enable_registration = True + config.macaroon_secret_key = "not even a little secret" + config.expire_access_token = False + config.server_name = name + config.trusted_third_party_id_servers = [] + config.room_invite_state_types = [] + config.password_providers = [] + config.worker_replication_url = "" + config.worker_app = None + config.email_enable_notifs = False + config.block_non_admin_invites = False + config.federation_domain_whitelist = None + config.federation_rc_reject_limit = 10 + config.federation_rc_sleep_limit = 10 + config.federation_rc_sleep_delay = 100 + config.federation_rc_concurrent = 10 + config.filter_timeline_limit = 5000 + config.user_directory_search_all_users = False + config.user_consent_server_notice_content = None + config.block_events_without_consent_error = None + config.media_storage_providers = [] + config.autocreate_auto_join_rooms = True + config.auto_join_rooms = [] + config.limit_usage_by_mau = False + config.hs_disabled = False + config.hs_disabled_message = "" + config.hs_disabled_limit_type = "" + config.max_mau_value = 50 + config.mau_trial_days = 0 + config.mau_limits_reserved_threepids = [] + config.admin_contact = None + config.rc_messages_per_second = 10000 + config.rc_message_burst_count = 10000 + + config.use_frozen_dicts = False + + # we need a sane default_room_version, otherwise attempts to create rooms will + # fail. + config.default_room_version = "1" + + # disable user directory updates, because they get done in the + # background, which upsets the test runner. + config.update_user_directory = False + + def is_threepid_reserved(threepid): + return ServerConfig.is_threepid_reserved(config, threepid) + + config.is_threepid_reserved.side_effect = is_threepid_reserved + + return config + + class TestHomeServer(HomeServer): DATASTORE_CLASS = DataStore @defer.inlineCallbacks def setup_test_homeserver( - cleanup_func, name="test", datastore=None, config=None, reactor=None, - homeserverToUse=TestHomeServer, **kargs + cleanup_func, + name="test", + datastore=None, + config=None, + reactor=None, + homeserverToUse=TestHomeServer, + **kargs ): """ Setup a homeserver suitable for running tests against. Keyword arguments @@ -117,55 +183,8 @@ def setup_test_homeserver( from twisted.internet import reactor if config is None: - config = Mock() - config.signing_key = [MockKey()] - config.event_cache_size = 1 - config.enable_registration = True - config.macaroon_secret_key = "not even a little secret" - config.expire_access_token = False - config.server_name = name - config.trusted_third_party_id_servers = [] - config.room_invite_state_types = [] - config.password_providers = [] - config.worker_replication_url = "" - config.worker_app = None - config.email_enable_notifs = False - config.block_non_admin_invites = False - config.federation_domain_whitelist = None - config.federation_rc_reject_limit = 10 - config.federation_rc_sleep_limit = 10 - config.federation_rc_sleep_delay = 100 - config.federation_rc_concurrent = 10 - config.filter_timeline_limit = 5000 - config.user_directory_search_all_users = False - config.user_consent_server_notice_content = None - config.block_events_without_consent_error = None - config.media_storage_providers = [] - config.auto_join_rooms = [] - config.limit_usage_by_mau = False - config.hs_disabled = False - config.hs_disabled_message = "" - config.hs_disabled_limit_type = "" - config.max_mau_value = 50 - config.mau_limits_reserved_threepids = [] - config.admin_contact = None - config.rc_messages_per_second = 10000 - config.rc_message_burst_count = 10000 - - # we need a sane default_room_version, otherwise attempts to create rooms will - # fail. - config.default_room_version = "1" - - # disable user directory updates, because they get done in the - # background, which upsets the test runner. - config.update_user_directory = False - - def is_threepid_reserved(threepid): - return ServerConfig.is_threepid_reserved(config, threepid) - - config.is_threepid_reserved.side_effect = is_threepid_reserved - - config.use_frozen_dicts = True + config = default_config(name) + config.ldap_enabled = False if "clock" not in kargs: @@ -231,20 +250,41 @@ def setup_test_homeserver( else: # We need to do cleanup on PostgreSQL def cleanup(): + import psycopg2 + # Close all the db pools hs.get_db_pool().close() + dropped = False + # Drop the test database db_conn = db_engine.module.connect( database=POSTGRES_BASE_DB, user=POSTGRES_USER ) db_conn.autocommit = True cur = db_conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for x in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + cur.close() db_conn.close() + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + if not LEAVE_DB: # Register the cleanup hook cleanup_func(cleanup) @@ -322,8 +362,7 @@ class MockHttpResource(HttpServer): @patch('twisted.web.http.Request') @defer.inlineCallbacks def trigger( - self, http_method, path, content, mock_request, - federation_auth_origin=None, + self, http_method, path, content, mock_request, federation_auth_origin=None ): """ Fire an HTTP event. @@ -356,7 +395,7 @@ class MockHttpResource(HttpServer): headers = {} if federation_auth_origin is not None: headers[b"Authorization"] = [ - b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin, ) + b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,) ] mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) @@ -576,16 +615,16 @@ def create_room(hs, room_id, creator_id): event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() - builder = event_builder_factory.new({ - "type": EventTypes.Create, - "state_key": "", - "sender": creator_id, - "room_id": room_id, - "content": {}, - }) - - event, context = yield event_creation_handler.create_new_client_event( - builder + builder = event_builder_factory.new( + { + "type": EventTypes.Create, + "state_key": "", + "sender": creator_id, + "room_id": room_id, + "content": {}, + } ) + event, context = yield event_creation_handler.create_new_client_event(builder) + yield store.persist_event(event, context) |