diff options
Diffstat (limited to 'tests')
56 files changed, 4024 insertions, 983 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 022d81ce3e..379e9c4ab1 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -457,8 +457,8 @@ class AuthTestCase(unittest.TestCase): with self.assertRaises(ResourceLimitError) as e: yield self.auth.check_auth_blocking() - self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) # Ensure does not throw an error @@ -468,11 +468,37 @@ class AuthTestCase(unittest.TestCase): yield self.auth.check_auth_blocking() @defer.inlineCallbacks + def test_reserved_threepid(self): + self.hs.config.limit_usage_by_mau = True + self.hs.config.max_mau_value = 1 + self.store.get_monthly_active_count = lambda: defer.succeed(2) + threepid = {'medium': 'email', 'address': 'reserved@server.com'} + unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'} + self.hs.config.mau_limits_reserved_threepids = [threepid] + + yield self.store.register(user_id='user1', token="123", password_hash=None) + with self.assertRaises(ResourceLimitError): + yield self.auth.check_auth_blocking() + + with self.assertRaises(ResourceLimitError): + yield self.auth.check_auth_blocking(threepid=unknown_threepid) + + yield self.auth.check_auth_blocking(threepid=threepid) + + @defer.inlineCallbacks def test_hs_disabled(self): self.hs.config.hs_disabled = True self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: yield self.auth.check_auth_blocking() - self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) + + @defer.inlineCallbacks + def test_server_notices_mxid_special_cased(self): + self.hs.config.hs_disabled = True + user = "@user:server" + self.hs.config.server_notices_mxid = user + self.hs.config.hs_disabled_message = "Reason for being disabled" + yield self.auth.check_auth_blocking(user) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 48b2d3d663..2a7044801a 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -60,7 +60,7 @@ class FilteringTestCase(unittest.TestCase): invalid_filters = [ {"boom": {}}, {"account_data": "Hello World"}, - {"event_fields": ["\\foo"]}, + {"event_fields": [r"\\foo"]}, {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}}, {"event_format": "other"}, {"room": {"not_rooms": ["#foo:pik-test"]}}, @@ -109,6 +109,16 @@ class FilteringTestCase(unittest.TestCase): "event_format": "client", "event_fields": ["type", "content", "sender"], }, + + # a single backslash should be permitted (though it is debatable whether + # it should be permitted before anything other than `.`, and what that + # actually means) + # + # (note that event_fields is implemented in + # synapse.events.utils.serialize_event, and so whether this actually works + # is tested elsewhere. We just want to check that it is allowed through the + # filter validation) + {"event_fields": [r"foo\.bar"]}, ] for filter in valid_filters: try: diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py index 76b5090fff..a83f567ebd 100644 --- a/tests/app/test_frontend_proxy.py +++ b/tests/app/test_frontend_proxy.py @@ -47,7 +47,7 @@ class FrontendProxyTests(HomeserverTestCase): self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] self.resource = ( - site.resource.children["_matrix"].children["client"].children["r0"] + site.resource.children[b"_matrix"].children[b"client"].children[b"r0"] ) request, channel = self.make_request("PUT", "presence/a/status") @@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase): self.assertEqual(len(self.reactor.tcpServers), 1) site = self.reactor.tcpServers[0][1] self.resource = ( - site.resource.children["_matrix"].children["client"].children["r0"] + site.resource.children[b"_matrix"].children[b"client"].children[b"r0"] ) request, channel = self.make_request("PUT", "presence/a/status") diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index f88d28a19d..0c23068bcf 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -67,6 +67,6 @@ class ConfigGenerationTestCase(unittest.TestCase): with open(log_config_file) as f: config = f.read() # find the 'filename' line - matches = re.findall("^\s*filename:\s*(.*)$", config, re.M) + matches = re.findall(r"^\s*filename:\s*(.*)$", config, re.M) self.assertEqual(1, len(matches)) self.assertEqual(matches[0], expected) diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py new file mode 100644 index 0000000000..f37a17d618 --- /dev/null +++ b/tests/config/test_room_directory.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import yaml + +from synapse.config.room_directory import RoomDirectoryConfig + +from tests import unittest + + +class RoomDirectoryConfigTestCase(unittest.TestCase): + def test_alias_creation_acl(self): + config = yaml.load(""" + alias_creation_rules: + - user_id: "*bob*" + alias: "*" + action: "deny" + - user_id: "*" + alias: "#unofficial_*" + action: "allow" + - user_id: "@foo*:example.com" + alias: "*" + action: "allow" + - user_id: "@gah:example.com" + alias: "#goo:example.com" + action: "allow" + """) + + rd_config = RoomDirectoryConfig() + rd_config.read_config(config) + + self.assertFalse(rd_config.is_alias_creation_allowed( + user_id="@bob:example.com", + alias="#test:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@test:example.com", + alias="#unofficial_st:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@foobar:example.com", + alias="#test:example.com", + )) + + self.assertTrue(rd_config.is_alias_creation_allowed( + user_id="@gah:example.com", + alias="#goo:example.com", + )) + + self.assertFalse(rd_config.is_alias_creation_allowed( + user_id="@test:example.com", + alias="#test:example.com", + )) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index ff217ca8b9..d0cc492deb 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -156,7 +156,7 @@ class SerializeEventTestCase(unittest.TestCase): room_id="!foo:bar", content={"key.with.dots": {}}, ), - ["content.key\.with\.dots"], + [r"content.key\.with\.dots"], ), {"content": {"key.with.dots": {}}}, ) @@ -172,7 +172,7 @@ class SerializeEventTestCase(unittest.TestCase): "nested.dot.key": {"leaf.key": 42, "not_me_either": 1}, }, ), - ["content.nested\.dot\.key.leaf\.key"], + [r"content.nested\.dot\.key.leaf\.key"], ), {"content": {"nested.dot.key": {"leaf.key": 42}}}, ) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 56e7acd37c..a3aa0a1cf2 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,79 +14,79 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import synapse.api.errors import synapse.handlers.device import synapse.storage -from tests import unittest, utils +from tests import unittest user1 = "@boris:aaa" user2 = "@theresa:bbb" -class DeviceTestCase(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(DeviceTestCase, self).__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore - self.handler = None # type: synapse.handlers.device.DeviceHandler - self.clock = None # type: utils.MockClock - - @defer.inlineCallbacks - def setUp(self): - hs = yield utils.setup_test_homeserver(self.addCleanup) +class DeviceTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) self.handler = hs.get_device_handler() self.store = hs.get_datastore() - self.clock = hs.get_clock() + return hs + + def prepare(self, reactor, clock, hs): + # These tests assume that it starts 1000 seconds in. + self.reactor.advance(1000) - @defer.inlineCallbacks def test_device_is_created_if_doesnt_exist(self): - res = yield self.handler.check_device_registered( - user_id="@boris:foo", - device_id="fco", - initial_device_display_name="display name", + res = self.get_success( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="fco", + initial_device_display_name="display name", + ) ) self.assertEqual(res, "fco") - dev = yield self.handler.store.get_device("@boris:foo", "fco") + dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) self.assertEqual(dev["display_name"], "display name") - @defer.inlineCallbacks def test_device_is_preserved_if_exists(self): - res1 = yield self.handler.check_device_registered( - user_id="@boris:foo", - device_id="fco", - initial_device_display_name="display name", + res1 = self.get_success( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="fco", + initial_device_display_name="display name", + ) ) self.assertEqual(res1, "fco") - res2 = yield self.handler.check_device_registered( - user_id="@boris:foo", - device_id="fco", - initial_device_display_name="new display name", + res2 = self.get_success( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="fco", + initial_device_display_name="new display name", + ) ) self.assertEqual(res2, "fco") - dev = yield self.handler.store.get_device("@boris:foo", "fco") + dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) self.assertEqual(dev["display_name"], "display name") - @defer.inlineCallbacks def test_device_id_is_made_up_if_unspecified(self): - device_id = yield self.handler.check_device_registered( - user_id="@theresa:foo", - device_id=None, - initial_device_display_name="display", + device_id = self.get_success( + self.handler.check_device_registered( + user_id="@theresa:foo", + device_id=None, + initial_device_display_name="display", + ) ) - dev = yield self.handler.store.get_device("@theresa:foo", device_id) + dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) self.assertEqual(dev["display_name"], "display") - @defer.inlineCallbacks def test_get_devices_by_user(self): - yield self._record_users() + self._record_users() + + res = self.get_success(self.handler.get_devices_by_user(user1)) - res = yield self.handler.get_devices_by_user(user1) self.assertEqual(3, len(res)) device_map = {d["device_id"]: d for d in res} self.assertDictContainsSubset( @@ -119,11 +120,10 @@ class DeviceTestCase(unittest.TestCase): device_map["abc"], ) - @defer.inlineCallbacks def test_get_device(self): - yield self._record_users() + self._record_users() - res = yield self.handler.get_device(user1, "abc") + res = self.get_success(self.handler.get_device(user1, "abc")) self.assertDictContainsSubset( { "user_id": user1, @@ -135,59 +135,66 @@ class DeviceTestCase(unittest.TestCase): res, ) - @defer.inlineCallbacks def test_delete_device(self): - yield self._record_users() + self._record_users() # delete the device - yield self.handler.delete_device(user1, "abc") + self.get_success(self.handler.delete_device(user1, "abc")) # check the device was deleted - with self.assertRaises(synapse.api.errors.NotFoundError): - yield self.handler.get_device(user1, "abc") + res = self.handler.get_device(user1, "abc") + self.pump() + self.assertIsInstance( + self.failureResultOf(res).value, synapse.api.errors.NotFoundError + ) # we'd like to check the access token was invalidated, but that's a # bit of a PITA. - @defer.inlineCallbacks def test_update_device(self): - yield self._record_users() + self._record_users() update = {"display_name": "new display"} - yield self.handler.update_device(user1, "abc", update) + self.get_success(self.handler.update_device(user1, "abc", update)) - res = yield self.handler.get_device(user1, "abc") + res = self.get_success(self.handler.get_device(user1, "abc")) self.assertEqual(res["display_name"], "new display") - @defer.inlineCallbacks def test_update_unknown_device(self): update = {"display_name": "new_display"} - with self.assertRaises(synapse.api.errors.NotFoundError): - yield self.handler.update_device("user_id", "unknown_device_id", update) + res = self.handler.update_device("user_id", "unknown_device_id", update) + self.pump() + self.assertIsInstance( + self.failureResultOf(res).value, synapse.api.errors.NotFoundError + ) - @defer.inlineCallbacks def _record_users(self): # check this works for both devices which have a recorded client_ip, # and those which don't. - yield self._record_user(user1, "xyz", "display 0") - yield self._record_user(user1, "fco", "display 1", "token1", "ip1") - yield self._record_user(user1, "abc", "display 2", "token2", "ip2") - yield self._record_user(user1, "abc", "display 2", "token3", "ip3") + self._record_user(user1, "xyz", "display 0") + self._record_user(user1, "fco", "display 1", "token1", "ip1") + self._record_user(user1, "abc", "display 2", "token2", "ip2") + self._record_user(user1, "abc", "display 2", "token3", "ip3") + + self._record_user(user2, "def", "dispkay", "token4", "ip4") - yield self._record_user(user2, "def", "dispkay", "token4", "ip4") + self.reactor.advance(10000) - @defer.inlineCallbacks def _record_user( self, user_id, device_id, display_name, access_token=None, ip=None ): - device_id = yield self.handler.check_device_registered( - user_id=user_id, - device_id=device_id, - initial_device_display_name=display_name, + device_id = self.get_success( + self.handler.check_device_registered( + user_id=user_id, + device_id=device_id, + initial_device_display_name=display_name, + ) ) if ip is not None: - yield self.store.insert_client_ip( - user_id, access_token, ip, "user_agent", device_id + self.get_success( + self.store.insert_client_ip( + user_id, access_token, ip, "user_agent", device_id + ) ) - self.clock.advance_time(1000) + self.reactor.advance(1000) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index ec7355688b..8ae6556c0a 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,7 +18,9 @@ from mock import Mock from twisted.internet import defer +from synapse.config.room_directory import RoomDirectoryConfig from synapse.handlers.directory import DirectoryHandler +from synapse.rest.client.v1 import directory, room from synapse.types import RoomAlias from tests import unittest @@ -102,3 +104,49 @@ class DirectoryTestCase(unittest.TestCase): ) self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) + + +class TestCreateAliasACL(unittest.HomeserverTestCase): + user_id = "@test:test" + + servlets = [directory.register_servlets, room.register_servlets] + + def prepare(self, hs, reactor, clock): + # We cheekily override the config to add custom alias creation rules + config = {} + config["alias_creation_rules"] = [ + { + "user_id": "*", + "alias": "#unofficial_*", + "action": "allow", + } + ] + + rd_config = RoomDirectoryConfig() + rd_config.read_config(config) + + self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed + + return hs + + def test_denied(self): + room_id = self.helper.create_room_as(self.user_id) + + request, channel = self.make_request( + "PUT", + b"directory/room/%23test%3Atest", + ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ) + self.render(request) + self.assertEquals(403, channel.code, channel.result) + + def test_allowed(self): + room_id = self.helper.create_room_as(self.user_id) + + request, channel = self.make_request( + "PUT", + b"directory/room/%23unofficial_test%3Atest", + ('{"room_id":"%s"}' % (room_id,)).encode('ascii'), + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py new file mode 100644 index 0000000000..9e08eac0a5 --- /dev/null +++ b/tests/handlers/test_e2e_room_keys.py @@ -0,0 +1,397 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import mock + +from twisted.internet import defer + +import synapse.api.errors +import synapse.handlers.e2e_room_keys +import synapse.storage +from synapse.api import errors + +from tests import unittest, utils + +# sample room_key data for use in the tests +room_keys = { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": False, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } +} + + +class E2eRoomKeysHandlerTestCase(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) + self.hs = None # type: synapse.server.HomeServer + self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler + + @defer.inlineCallbacks + def setUp(self): + self.hs = yield utils.setup_test_homeserver( + self.addCleanup, + handlers=None, + replication_layer=mock.Mock(), + ) + self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) + self.local_user = "@boris:" + self.hs.hostname + + @defer.inlineCallbacks + def test_get_missing_current_version_info(self): + """Check that we get a 404 if we ask for info about the current version + if there is no version. + """ + res = None + try: + yield self.handler.get_version_info(self.local_user) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_get_missing_version_info(self): + """Check that we get a 404 if we ask for info about a specific version + if it doesn't exist. + """ + res = None + try: + yield self.handler.get_version_info(self.local_user, "bogus_version") + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_create_version(self): + """Check that we can create and then retrieve versions. + """ + res = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(res, "1") + + # check we can retrieve it as the current version + res = yield self.handler.get_version_info(self.local_user) + self.assertDictEqual(res, { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + + # check we can retrieve it as a specific version + res = yield self.handler.get_version_info(self.local_user, "1") + self.assertDictEqual(res, { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + + # upload a new one... + res = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }) + self.assertEqual(res, "2") + + # check we can retrieve it as the current version + res = yield self.handler.get_version_info(self.local_user) + self.assertDictEqual(res, { + "version": "2", + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }) + + @defer.inlineCallbacks + def test_delete_missing_version(self): + """Check that we get a 404 on deleting nonexistent versions + """ + res = None + try: + yield self.handler.delete_version(self.local_user, "1") + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_delete_missing_current_version(self): + """Check that we get a 404 on deleting nonexistent current version + """ + res = None + try: + yield self.handler.delete_version(self.local_user) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_delete_version(self): + """Check that we can create and then delete versions. + """ + res = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(res, "1") + + # check we can delete it + yield self.handler.delete_version(self.local_user, "1") + + # check that it's gone + res = None + try: + yield self.handler.get_version_info(self.local_user, "1") + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_get_missing_room_keys(self): + """Check that we get a 404 on querying missing room_keys + """ + res = None + try: + yield self.handler.get_room_keys(self.local_user, "bogus_version") + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # check we also get a 404 even if the version is valid + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + res = None + try: + yield self.handler.get_room_keys(self.local_user, version) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # TODO: test the locking semantics when uploading room_keys, + # although this is probably best done in sytest + + @defer.inlineCallbacks + def test_upload_room_keys_no_versions(self): + """Check that we get a 404 on uploading keys when no versions are defined + """ + res = None + try: + yield self.handler.upload_room_keys(self.local_user, "no_version", room_keys) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_upload_room_keys_bogus_version(self): + """Check that we get a 404 on uploading keys when an nonexistent version + is specified + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + res = None + try: + yield self.handler.upload_room_keys( + self.local_user, "bogus_version", room_keys + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_upload_room_keys_wrong_version(self): + """Check that we get a 403 on uploading keys for an old version + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }) + self.assertEqual(version, "2") + + res = None + try: + yield self.handler.upload_room_keys(self.local_user, "1", room_keys) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 403) + + @defer.inlineCallbacks + def test_upload_room_keys_insert(self): + """Check that we can insert and retrieve keys for a session + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertDictEqual(res, room_keys) + + # check getting room_keys for a given room + res = yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org" + ) + self.assertDictEqual(res, room_keys) + + # check getting room_keys for a given session_id + res = yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + self.assertDictEqual(res, room_keys) + + @defer.inlineCallbacks + def test_upload_room_keys_merge(self): + """Check that we can upload a new room_key for an existing session and + have it correctly merged""" + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + + new_room_keys = copy.deepcopy(room_keys) + new_room_key = new_room_keys['rooms']['!abc:matrix.org']['sessions']['c0ff33'] + + # test that increasing the message_index doesn't replace the existing session + new_room_key['first_message_index'] = 2 + new_room_key['session_data'] = 'new' + yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertEqual( + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + "SSBBTSBBIEZJU0gK" + ) + + # test that marking the session as verified however /does/ replace it + new_room_key['is_verified'] = True + yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertEqual( + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + "new" + ) + + # test that a session with a higher forwarded_count doesn't replace one + # with a lower forwarding count + new_room_key['forwarded_count'] = 2 + new_room_key['session_data'] = 'other' + yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) + + res = yield self.handler.get_room_keys(self.local_user, version) + self.assertEqual( + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], + "new" + ) + + # TODO: check edge cases as well as the common variations here + + @defer.inlineCallbacks + def test_delete_room_keys(self): + """Check that we can insert and delete keys for a session + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + # check for bulk-delete + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield self.handler.delete_room_keys(self.local_user, version) + res = None + try: + yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # check for bulk-delete per room + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield self.handler.delete_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + ) + res = None + try: + yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + # check for bulk-delete per session + yield self.handler.upload_room_keys(self.local_user, version, room_keys) + yield self.handler.delete_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + res = None + try: + yield self.handler.get_room_keys( + self.local_user, + version, + room_id="!abc:matrix.org", + session_id="c0ff33", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 40d9aca671..dc140570c6 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -20,7 +20,7 @@ from twisted.internet import defer import synapse.types from synapse.api.errors import AuthError -from synapse.handlers.profile import ProfileHandler +from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID from tests import unittest @@ -29,7 +29,7 @@ from tests.utils import setup_test_homeserver class ProfileHandlers(object): def __init__(self, hs): - self.profile_handler = ProfileHandler(hs) + self.profile_handler = MasterProfileHandler(hs) class ProfileTestCase(unittest.TestCase): diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 7b4ade3dfb..3e9a190727 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.api.errors import ResourceLimitError from synapse.handlers.register import RegistrationHandler -from synapse.types import UserID, create_requester +from synapse.types import RoomAlias, UserID, create_requester from tests.utils import setup_test_homeserver @@ -41,30 +41,27 @@ class RegistrationTestCase(unittest.TestCase): self.mock_captcha_client = Mock() self.hs = yield setup_test_homeserver( self.addCleanup, - handlers=None, - http_client=None, expire_access_token=True, - profile_handler=Mock(), ) self.macaroon_generator = Mock( generate_access_token=Mock(return_value='secret') ) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) - self.hs.handlers = RegistrationHandlers(self.hs) self.handler = self.hs.get_handlers().registration_handler self.store = self.hs.get_datastore() self.hs.config.max_mau_value = 50 self.lots_of_users = 100 self.small_number_of_users = 1 + self.requester = create_requester("@requester:test") + @defer.inlineCallbacks def test_user_is_created_and_logged_in_if_doesnt_exist(self): - local_part = "someone" - display_name = "someone" - user_id = "@someone:test" - requester = create_requester("@as:test") + frank = UserID.from_string("@frank:test") + user_id = frank.to_string() + requester = create_requester(user_id) result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name + requester, frank.localpart, "Frankie" ) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -78,12 +75,11 @@ class RegistrationTestCase(unittest.TestCase): token="jkv;g498752-43gj['eamb!-5", password_hash=None, ) - local_part = "frank" - display_name = "Frank" - user_id = "@frank:test" - requester = create_requester("@as:test") + local_part = frank.localpart + user_id = frank.to_string() + requester = create_requester(user_id) result_user_id, result_token = yield self.handler.get_or_create_user( - requester, local_part, display_name + requester, local_part, None ) self.assertEquals(result_user_id, user_id) self.assertEquals(result_token, 'secret') @@ -92,7 +88,7 @@ class RegistrationTestCase(unittest.TestCase): def test_mau_limits_when_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.handler.get_or_create_user("requester", 'a', "display_name") + yield self.handler.get_or_create_user(self.requester, 'a', "display_name") @defer.inlineCallbacks def test_get_or_create_user_mau_not_blocked(self): @@ -101,7 +97,7 @@ class RegistrationTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception - yield self.handler.get_or_create_user("@user:server", 'c', "User") + yield self.handler.get_or_create_user(self.requester, 'c', "User") @defer.inlineCallbacks def test_get_or_create_user_mau_blocked(self): @@ -110,13 +106,13 @@ class RegistrationTestCase(unittest.TestCase): return_value=defer.succeed(self.lots_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user("requester", 'b', "display_name") + yield self.handler.get_or_create_user(self.requester, 'b', "display_name") self.store.get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.handler.get_or_create_user("requester", 'b', "display_name") + yield self.handler.get_or_create_user(self.requester, 'b', "display_name") @defer.inlineCallbacks def test_register_mau_blocked(self): @@ -147,3 +143,44 @@ class RegistrationTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError): yield self.handler.register_saml2(localpart="local_part") + + @defer.inlineCallbacks + def test_auto_create_auto_join_rooms(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + res = yield self.handler.register(localpart='jeff') + rooms = yield self.store.get_rooms_for_user(res[0]) + + directory_handler = self.hs.get_handlers().directory_handler + room_alias = RoomAlias.from_string(room_alias_str) + room_id = yield directory_handler.get_association(room_alias) + + self.assertTrue(room_id['room_id'] in rooms) + self.assertEqual(len(rooms), 1) + + @defer.inlineCallbacks + def test_auto_create_auto_join_rooms_with_no_rooms(self): + self.hs.config.auto_join_rooms = [] + frank = UserID.from_string("@frank:test") + res = yield self.handler.register(frank.localpart) + self.assertEqual(res[0], frank.to_string()) + rooms = yield self.store.get_rooms_for_user(res[0]) + self.assertEqual(len(rooms), 0) + + @defer.inlineCallbacks + def test_auto_create_auto_join_where_room_is_another_domain(self): + self.hs.config.auto_join_rooms = ["#room:another"] + frank = UserID.from_string("@frank:test") + res = yield self.handler.register(frank.localpart) + self.assertEqual(res[0], frank.to_string()) + rooms = yield self.store.get_rooms_for_user(res[0]) + self.assertEqual(len(rooms), 0) + + @defer.inlineCallbacks + def test_auto_create_auto_join_where_auto_create_is_false(self): + self.hs.config.autocreate_auto_join_rooms = False + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + res = yield self.handler.register(localpart='jeff') + rooms = yield self.store.get_rooms_for_user(res[0]) + self.assertEqual(len(rooms), 0) diff --git a/tests/handlers/test_roomlist.py b/tests/handlers/test_roomlist.py new file mode 100644 index 0000000000..61eebb6985 --- /dev/null +++ b/tests/handlers/test_roomlist.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.handlers.room_list import RoomListNextBatch + +import tests.unittest +import tests.utils + + +class RoomListTestCase(tests.unittest.TestCase): + """ Tests RoomList's RoomListNextBatch. """ + + def setUp(self): + pass + + def test_check_read_batch_tokens(self): + batch_token = RoomListNextBatch( + stream_ordering="abcdef", + public_room_stream_id="123", + current_limit=20, + direction_is_forward=True, + ).to_token() + next_batch = RoomListNextBatch.from_token(batch_token) + self.assertEquals(next_batch.stream_ordering, "abcdef") + self.assertEquals(next_batch.public_room_stream_id, "123") + self.assertEquals(next_batch.current_limit, 20) + self.assertEquals(next_batch.direction_is_forward, True) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index a01ab471f5..31f54bbd7d 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -51,7 +51,7 @@ class SyncTestCase(tests.unittest.TestCase): self.hs.config.hs_disabled = True with self.assertRaises(ResourceLimitError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.hs.config.hs_disabled = False @@ -59,7 +59,7 @@ class SyncTestCase(tests.unittest.TestCase): with self.assertRaises(ResourceLimitError) as e: yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED) + self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def _generate_sync_config(self, user_id): return SyncConfig( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ad58073a14..36e136cded 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -33,7 +33,7 @@ from ..utils import ( ) -def _expect_edu(destination, edu_type, content, origin="test"): +def _expect_edu_transaction(edu_type, content, origin="test"): return { "origin": origin, "origin_server_ts": 1000000, @@ -42,10 +42,8 @@ def _expect_edu(destination, edu_type, content, origin="test"): } -def _make_edu_json(origin, edu_type, content): - return json.dumps(_expect_edu("test", edu_type, content, origin=origin)).encode( - 'utf8' - ) +def _make_edu_transaction_json(edu_type, content): + return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8') class TypingNotificationsTestCase(unittest.TestCase): @@ -190,8 +188,7 @@ class TypingNotificationsTestCase(unittest.TestCase): call( "farm", path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu( - "farm", + data=_expect_edu_transaction( "m.typing", content={ "room_id": self.room_id, @@ -221,11 +218,10 @@ class TypingNotificationsTestCase(unittest.TestCase): self.assertEquals(self.event_source.get_current_key(), 0) - yield self.mock_federation_resource.trigger( + (code, response) = yield self.mock_federation_resource.trigger( "PUT", "/_matrix/federation/v1/send/1000000/", - _make_edu_json( - "farm", + _make_edu_transaction_json( "m.typing", content={ "room_id": self.room_id, @@ -233,7 +229,7 @@ class TypingNotificationsTestCase(unittest.TestCase): "typing": True, }, ), - federation_auth=True, + federation_auth_origin=b'farm', ) self.on_new_event.assert_has_calls( @@ -264,8 +260,7 @@ class TypingNotificationsTestCase(unittest.TestCase): call( "farm", path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu( - "farm", + data=_expect_edu_transaction( "m.typing", content={ "room_id": self.room_id, diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py new file mode 100644 index 0000000000..f3cb1423f0 --- /dev/null +++ b/tests/http/test_fedclient.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet.defer import TimeoutError +from twisted.internet.error import ConnectingCancelledError, DNSLookupError +from twisted.web.client import ResponseNeverReceived +from twisted.web.http import HTTPChannel + +from synapse.http.matrixfederationclient import ( + MatrixFederationHttpClient, + MatrixFederationRequest, +) + +from tests.server import FakeTransport +from tests.unittest import HomeserverTestCase + + +class FederationClientTests(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + hs = self.setup_test_homeserver(reactor=reactor, clock=clock) + hs.tls_client_options_factory = None + return hs + + def prepare(self, reactor, clock, homeserver): + + self.cl = MatrixFederationHttpClient(self.hs) + self.reactor.lookups["testserv"] = "1.2.3.4" + + def test_dns_error(self): + """ + If the DNS raising returns an error, it will bubble up. + """ + d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + self.pump() + + f = self.failureResultOf(d) + self.assertIsInstance(f.value, DNSLookupError) + + def test_client_never_connect(self): + """ + If the HTTP request is not connected and is timed out, it'll give a + ConnectingCancelledError or TimeoutError. + """ + d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + + self.pump() + + # Nothing happened yet + self.assertFalse(d.called) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][1], 8008) + + # Deferred is still without a result + self.assertFalse(d.called) + + # Push by enough to time it out + self.reactor.advance(10.5) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, (ConnectingCancelledError, TimeoutError)) + + def test_client_connect_no_response(self): + """ + If the HTTP request is connected, but gets no response before being + timed out, it'll give a ResponseNeverReceived. + """ + d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + + self.pump() + + # Nothing happened yet + self.assertFalse(d.called) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + self.assertEqual(clients[0][0], '1.2.3.4') + self.assertEqual(clients[0][1], 8008) + + conn = Mock() + client = clients[0][2].buildProtocol(None) + client.makeConnection(conn) + + # Deferred is still without a result + self.assertFalse(d.called) + + # Push by enough to time it out + self.reactor.advance(10.5) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, ResponseNeverReceived) + + def test_client_gets_headers(self): + """ + Once the client gets the headers, _request returns successfully. + """ + request = MatrixFederationRequest( + method="GET", + destination="testserv:8008", + path="foo/bar", + ) + d = self.cl._send_request(request, timeout=10000) + + self.pump() + + conn = Mock() + clients = self.reactor.tcpClients + client = clients[0][2].buildProtocol(None) + client.makeConnection(conn) + + # Deferred does not have a result + self.assertFalse(d.called) + + # Send it the HTTP response + client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n") + + # We should get a successful response + r = self.successResultOf(d) + self.assertEqual(r.code, 200) + + def test_client_headers_no_body(self): + """ + If the HTTP request is connected, but gets no response before being + timed out, it'll give a ResponseNeverReceived. + """ + d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + + self.pump() + + conn = Mock() + clients = self.reactor.tcpClients + client = clients[0][2].buildProtocol(None) + client.makeConnection(conn) + + # Deferred does not have a result + self.assertFalse(d.called) + + # Send it the HTTP response + client.dataReceived( + (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + b"Server: Fake\r\n\r\n") + ) + + # Push by enough to time it out + self.reactor.advance(10.5) + f = self.failureResultOf(d) + + self.assertIsInstance(f.value, TimeoutError) + + def test_client_sends_body(self): + self.cl.post_json( + "testserv:8008", "foo/bar", timeout=10000, + data={"a": "b"} + ) + + self.pump() + + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + client = clients[0][2].buildProtocol(None) + server = HTTPChannel() + + client.makeConnection(FakeTransport(server, self.reactor)) + server.makeConnection(FakeTransport(client, self.reactor)) + + self.pump(0.1) + + self.assertEqual(len(server.requests), 1) + request = server.requests[0] + content = request.content.read() + self.assertEqual(content, b'{"a":"b"}') diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 65df116efc..9e9fbbfe93 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -1,4 +1,5 @@ # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,89 +12,62 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import tempfile from mock import Mock, NonCallableMock -from twisted.internet import defer, reactor -from twisted.internet.defer import Deferred - from synapse.replication.tcp.client import ( ReplicationClientFactory, ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory -from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable from tests import unittest -from tests.utils import setup_test_homeserver - - -class TestReplicationClientHandler(ReplicationClientHandler): - """Overrides on_rdata so that we can wait for it to happen""" - - def __init__(self, store): - super(TestReplicationClientHandler, self).__init__(store) - self._rdata_awaiters = [] - - def await_replication(self): - d = Deferred() - self._rdata_awaiters.append(d) - return make_deferred_yieldable(d) +from tests.server import FakeTransport - def on_rdata(self, stream_name, token, rows): - awaiters = self._rdata_awaiters - self._rdata_awaiters = [] - super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows) - with PreserveLoggingContext(): - for a in awaiters: - a.callback(None) +class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): -class BaseSlavedStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver( - self.addCleanup, + hs = self.setup_test_homeserver( "blue", - http_client=None, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) - self.hs.get_ratelimiter().send_message.return_value = (True, 0) + + hs.get_ratelimiter().send_message.return_value = (True, 0) + + return hs + + def prepare(self, reactor, clock, hs): self.master_store = self.hs.get_datastore() self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - # XXX: mktemp is unsafe and should never be used. but we're just a test. - path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket") - listener = reactor.listenUNIX(path, server_factory) - self.addCleanup(listener.stopListening) self.streamer = server_factory.streamer - self.replication_handler = TestReplicationClientHandler(self.slaved_store) + self.replication_handler = ReplicationClientHandler(self.slaved_store) client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) - client_connector = reactor.connectUNIX(path, client_factory) - self.addCleanup(client_factory.stopTrying) - self.addCleanup(client_connector.disconnect) + + server = server_factory.buildProtocol(None) + client = client_factory.buildProtocol(None) + + client.makeConnection(FakeTransport(server, reactor)) + server.makeConnection(FakeTransport(client, reactor)) def replicate(self): """Tell the master side of replication that something has happened, and then wait for the replication to occur. """ - # xxx: should we be more specific in what we wait for? - d = self.replication_handler.await_replication() self.streamer.on_notifier_poke() - return d + self.pump(0.1) - @defer.inlineCallbacks def check(self, method, args, expected_result=None): - master_result = yield getattr(self.master_store, method)(*args) - slaved_result = yield getattr(self.slaved_store, method)(*args) + master_result = self.get_success(getattr(self.master_store, method)(*args)) + slaved_result = self.get_success(getattr(self.slaved_store, method)(*args)) if expected_result is not None: self.assertEqual(master_result, expected_result) self.assertEqual(slaved_result, expected_result) diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py index 87cc2b2fba..43e3248703 100644 --- a/tests/replication/slave/storage/test_account_data.py +++ b/tests/replication/slave/storage/test_account_data.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from ._base import BaseSlavedStoreTestCase @@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): STORE_TYPE = SlavedAccountDataStore - @defer.inlineCallbacks def test_user_account_data(self): - yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) - yield self.replicate() - yield self.check( + self.get_success( + self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) + ) + self.replicate() + self.check( "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1} ) - yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) - yield self.replicate() - yield self.check( + self.get_success( + self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) + ) + self.replicate() + self.check( "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2} ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 622be2eef8..41be5d5a1a 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from canonicaljson import encode_canonical_json from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext @@ -28,7 +28,9 @@ ROOM_ID = "!room:blue" def dict_equals(self, other): - return self.__dict__ == other.__dict__ + me = encode_canonical_json(self._event_dict) + them = encode_canonical_json(other._event_dict) + return me == them def patch__eq__(cls): @@ -55,69 +57,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): def tearDown(self): [unpatch() for unpatch in self.unpatches] - @defer.inlineCallbacks def test_get_latest_event_ids_in_room(self): - create = yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.replicate() - yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) + create = self.persist(type="m.room.create", key="", creator=USER_ID) + self.replicate() + self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) - join = yield self.persist( + join = self.persist( type="m.room.member", key=USER_ID, membership="join", prev_events=[(create.event_id, {})], ) - yield self.replicate() - yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) + self.replicate() + self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) - @defer.inlineCallbacks def test_redactions(self): - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.persist(type="m.room.member", key=USER_ID, membership="join") + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") - msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") - yield self.replicate() - yield self.check("get_event", [msg.event_id], msg) + msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") + self.replicate() + self.check("get_event", [msg.event_id], msg) - redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id) - yield self.replicate() + redaction = self.persist(type="m.room.redaction", redacts=msg.event_id) + self.replicate() msg_dict = msg.get_dict() msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] = redaction redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) - yield self.check("get_event", [msg.event_id], redacted) + self.check("get_event", [msg.event_id], redacted) - @defer.inlineCallbacks def test_backfilled_redactions(self): - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.persist(type="m.room.member", key=USER_ID, membership="join") + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.member", key=USER_ID, membership="join") - msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") - yield self.replicate() - yield self.check("get_event", [msg.event_id], msg) + msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") + self.replicate() + self.check("get_event", [msg.event_id], msg) - redaction = yield self.persist( + redaction = self.persist( type="m.room.redaction", redacts=msg.event_id, backfill=True ) - yield self.replicate() + self.replicate() msg_dict = msg.get_dict() msg_dict["content"] = {} msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_because"] = redaction redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) - yield self.check("get_event", [msg.event_id], redacted) + self.check("get_event", [msg.event_id], redacted) - @defer.inlineCallbacks def test_invites(self): - yield self.check("get_invited_rooms_for_user", [USER_ID_2], []) - event = yield self.persist( - type="m.room.member", key=USER_ID_2, membership="invite" - ) - yield self.replicate() - yield self.check( + self.persist(type="m.room.create", key="", creator=USER_ID) + self.check("get_invited_rooms_for_user", [USER_ID_2], []) + event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") + + self.replicate() + + self.check( "get_invited_rooms_for_user", [USER_ID_2], [ @@ -131,37 +130,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ], ) - @defer.inlineCallbacks def test_push_actions_for_user(self): - yield self.persist(type="m.room.create", creator=USER_ID) - yield self.persist(type="m.room.join", key=USER_ID, membership="join") - yield self.persist( + self.persist(type="m.room.create", key="", creator=USER_ID) + self.persist(type="m.room.join", key=USER_ID, membership="join") + self.persist( type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join" ) - event1 = yield self.persist( - type="m.room.message", msgtype="m.text", body="hello" - ) - yield self.replicate() - yield self.check( + event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello") + self.replicate() + self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], {"highlight_count": 0, "notify_count": 0}, ) - yield self.persist( + self.persist( type="m.room.message", msgtype="m.text", body="world", push_actions=[(USER_ID_2, ["notify"])], ) - yield self.replicate() - yield self.check( + self.replicate() + self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], {"highlight_count": 0, "notify_count": 1}, ) - yield self.persist( + self.persist( type="m.room.message", msgtype="m.text", body="world", @@ -169,8 +165,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}]) ], ) - yield self.replicate() - yield self.check( + self.replicate() + self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2, event1.event_id], {"highlight_count": 1, "notify_count": 2}, @@ -178,7 +174,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event_id = 0 - @defer.inlineCallbacks def persist( self, sender=USER_ID, @@ -205,8 +200,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): depth = self.event_id if not prev_events: - latest_event_ids = yield self.master_store.get_latest_event_ids_in_room( - room_id + latest_event_ids = self.get_success( + self.master_store.get_latest_event_ids_in_room(room_id) ) prev_events = [(ev_id, {}) for ev_id in latest_event_ids] @@ -239,19 +234,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) else: state_handler = self.hs.get_state_handler() - context = yield state_handler.compute_event_context(event) + context = self.get_success(state_handler.compute_event_context(event)) - yield self.master_store.add_push_actions_to_staging( + self.master_store.add_push_actions_to_staging( event.event_id, {user_id: actions for user_id, actions in push_actions} ) ordering = None if backfill: - yield self.master_store.persist_events([(event, context)], backfilled=True) + self.get_success( + self.master_store.persist_events([(event, context)], backfilled=True) + ) else: - ordering, _ = yield self.master_store.persist_event(event, context) + ordering, _ = self.get_success( + self.master_store.persist_event(event, context) + ) if ordering: event.internal_metadata.stream_ordering = ordering - defer.returnValue(event) + return event diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index ae1adeded1..f47d94f690 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from ._base import BaseSlavedStoreTestCase @@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): STORE_TYPE = SlavedReceiptsStore - @defer.inlineCallbacks def test_receipt(self): - yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) - yield self.master_store.insert_receipt( - ROOM_ID, "m.read", USER_ID, [EVENT_ID], {} - ) - yield self.replicate() - yield self.check( - "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID} + self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) + self.get_success( + self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}) ) + self.replicate() + self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 9fe0760496..a824be9a62 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -22,39 +22,24 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer -import synapse.rest.client.v1.room from synapse.api.constants import Membership -from synapse.http.server import JsonResource -from synapse.types import UserID -from synapse.util import Clock +from synapse.rest.client.v1 import admin, login, room from tests import unittest -from tests.server import ( - ThreadedMemoryReactorClock, - make_request, - render, - setup_test_homeserver, -) - -from .utils import RestHelper PATH_PREFIX = b"/_matrix/client/api/v1" -class RoomBase(unittest.TestCase): +class RoomBase(unittest.HomeserverTestCase): rmcreator_id = None - def setUp(self): + servlets = [room.register_servlets, room.register_deprecated_servlets] - self.clock = ThreadedMemoryReactorClock() - self.hs_clock = Clock(self.clock) + def make_homeserver(self, reactor, clock): - self.hs = setup_test_homeserver( - self.addCleanup, + self.hs = self.setup_test_homeserver( "red", http_client=None, - clock=self.hs_clock, - reactor=self.clock, federation_client=Mock(), ratelimiter=NonCallableMock(spec_set=["send_message"]), ) @@ -63,42 +48,21 @@ class RoomBase(unittest.TestCase): self.hs.get_federation_handler = Mock(return_value=Mock()) - def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } - - def get_user_by_req(request, allow_guest=False, rights="access"): - return synapse.types.create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None - ) - - self.hs.get_auth().get_user_by_req = get_user_by_req - self.hs.get_auth().get_user_by_access_token = get_user_by_access_token - self.hs.get_auth().get_access_token_from_request = Mock(return_value=b"1234") - def _insert_client_ip(*args, **kwargs): return defer.succeed(None) self.hs.get_datastore().insert_client_ip = _insert_client_ip - self.resource = JsonResource(self.hs) - synapse.rest.client.v1.room.register_servlets(self.hs, self.resource) - synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource) - self.helper = RestHelper(self.hs, self.resource, self.user_id) + return self.hs class RoomPermissionsTestCase(RoomBase): """ Tests room permissions. """ - user_id = b"@sid1:red" - rmcreator_id = b"@notme:red" - - def setUp(self): + user_id = "@sid1:red" + rmcreator_id = "@notme:red" - super(RoomPermissionsTestCase, self).setUp() + def prepare(self, reactor, clock, hs): self.helper.auth_user_id = self.rmcreator_id # create some rooms under the name rmcreator_id @@ -114,22 +78,20 @@ class RoomPermissionsTestCase(RoomBase): self.created_rmid_msg_path = ( "rooms/%s/send/m.room.message/a1" % (self.created_rmid) ).encode('ascii') - request, channel = make_request( - b"PUT", - self.created_rmid_msg_path, - b'{"msgtype":"m.text","body":"test msg"}', + request, channel = self.make_request( + "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - render(request, self.resource, self.clock) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.render(request) + self.assertEquals(200, channel.code, channel.result) # set topic for public room - request, channel = make_request( - b"PUT", + request, channel = self.make_request( + "PUT", ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode('ascii'), b'{"topic":"Public Room Topic"}', ) - render(request, self.resource, self.clock) - self.assertEquals(channel.result["code"], b"200", channel.result) + self.render(request) + self.assertEquals(200, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -140,128 +102,128 @@ class RoomPermissionsTestCase(RoomBase): seq = iter(range(100)) def send_msg_path(): - return b"/rooms/%s/send/m.room.message/mid%s" % ( + return "/rooms/%s/send/m.room.message/mid%s" % ( self.created_rmid, - str(next(seq)).encode('ascii'), + str(next(seq)), ) # send message in uncreated room, expect 403 - request, channel = make_request( - b"PUT", - b"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), + request, channel = self.make_request( + "PUT", + "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = make_request(b"PUT", send_msg_path(), msg_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", send_msg_path(), msg_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_topic_perms(self): topic_content = b'{"topic":"My Topic Name"}' - topic_path = b"/rooms/%s/state/m.room.topic" % self.created_rmid + topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid # set/get topic in uncreated room, expect 403 - request, channel = make_request( - b"PUT", b"/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content + request, channel = self.make_request( + "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - request, channel = make_request( - b"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + request, channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assert_dict(json.loads(topic_content), channel.json_body) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assert_dict(json.loads(topic_content.decode('utf8')), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) - request, channel = make_request(b"PUT", topic_path, topic_content) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) - request, channel = make_request(b"GET", topic_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", topic_path, topic_content) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + request, channel = self.make_request("GET", topic_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 - request, channel = make_request( - b"GET", b"/rooms/%s/state/m.room.topic" % self.created_public_rmid + request, channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 - request, channel = make_request( - b"PUT", - b"/rooms/%s/state/m.room.topic" % self.created_public_rmid, + request, channel = self.make_request( + "PUT", + "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def _test_get_membership(self, room=None, members=[], expect_code=None): for member in members: - path = b"/rooms/%s/state/m.room.member/%s" % (room, member) - request, channel = make_request(b"GET", path) - render(request, self.resource, self.clock) - self.assertEquals(expect_code, int(channel.result["code"])) + path = "/rooms/%s/state/m.room.member/%s" % (room, member) + request, channel = self.make_request("GET", path) + self.render(request) + self.assertEquals(expect_code, channel.code) def test_membership_basic_room_perms(self): # === room does not exist === @@ -428,217 +390,211 @@ class RoomPermissionsTestCase(RoomBase): class RoomsMemberListTestCase(RoomBase): """ Tests /rooms/$room_id/members/list REST events.""" - user_id = b"@sid1:red" + user_id = "@sid1:red" def test_get_member_list(self): room_id = self.helper.create_room_as(self.user_id) - request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self): - request, channel = make_request(b"GET", b"/rooms/roomdoesnotexist/members") - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self): - room_id = self.helper.create_room_as(b"@some_other_guy:red") - request, channel = make_request(b"GET", b"/rooms/%s/members" % room_id) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + room_id = self.helper.create_room_as("@some_other_guy:red") + request, channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self): - room_creator = b"@some_other_guy:red" + room_creator = "@some_other_guy:red" room_id = self.helper.create_room_as(room_creator) - room_path = b"/rooms/%s/members" % room_id + room_path = "/rooms/%s/members" % room_id self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. - request, channel = make_request(b"GET", room_path) - render(request, self.resource, self.clock) - self.assertEquals(403, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", room_path) + self.render(request) + self.assertEquals(403, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined - request, channel = make_request(b"GET", room_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", room_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left - request, channel = make_request(b"GET", room_path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", room_path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) class RoomsCreateTestCase(RoomBase): """ Tests /rooms and /rooms/$room_id REST events. """ - user_id = b"@sid1:red" + user_id = "@sid1:red" def test_post_room_no_keys(self): # POST with no config keys, expect new room id - request, channel = make_request(b"POST", b"/createRoom", b"{}") + request, channel = self.make_request("POST", "/createRoom", "{}") - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), channel.result) + self.render(request) + self.assertEquals(200, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) def test_post_room_visibility_key(self): # POST with visibility config key, expect new room id - request, channel = make_request( - b"POST", b"/createRoom", b'{"visibility":"private"}' + request, channel = self.make_request( + "POST", "/createRoom", b'{"visibility":"private"}' ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self): # POST with custom config keys, expect new room id - request, channel = make_request(b"POST", b"/createRoom", b'{"custom":"stuff"}') - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + request, channel = self.make_request( + "POST", "/createRoom", b'{"custom":"stuff"}' + ) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self): # POST with custom + known config keys, expect new room id - request, channel = make_request( - b"POST", b"/createRoom", b'{"visibility":"private","custom":"things"}' + request, channel = self.make_request( + "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self): # POST with invalid content / paths, expect 400 - request, channel = make_request(b"POST", b"/createRoom", b'{"visibili') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"])) + request, channel = self.make_request("POST", "/createRoom", b'{"visibili') + self.render(request) + self.assertEquals(400, channel.code) - request, channel = make_request(b"POST", b"/createRoom", b'["hello"]') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"])) + request, channel = self.make_request("POST", "/createRoom", b'["hello"]') + self.render(request) + self.assertEquals(400, channel.code) class RoomTopicTestCase(RoomBase): """ Tests /rooms/$room_id/topic REST events. """ - user_id = b"@sid1:red" - - def setUp(self): - - super(RoomTopicTestCase, self).setUp() + user_id = "@sid1:red" + def prepare(self, reactor, clock, hs): # create the room self.room_id = self.helper.create_room_as(self.user_id) - self.path = b"/rooms/%s/state/m.room.topic" % (self.room_id,) + self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) def test_invalid_puts(self): # missing keys or invalid json - request, channel = make_request(b"PUT", self.path, '{}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '{}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, '{"_name":"bob"}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, '{"nao') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '{"nao') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request( - b"PUT", self.path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = self.make_request( + "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, 'text only') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, 'text only') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", self.path, '') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, '') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid key, wrong type content = '{"topic":["Topic name"]}' - request, channel = make_request(b"PUT", self.path, content) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, content) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_topic(self): # nothing should be there - request, channel = make_request(b"GET", self.path) - render(request, self.resource, self.clock) - self.assertEquals(404, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", self.path) + self.render(request) + self.assertEquals(404, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' - request, channel = make_request(b"PUT", self.path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = make_request(b"GET", self.path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", self.path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self): # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' - request, channel = make_request(b"PUT", self.path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", self.path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # valid get - request, channel = make_request(b"GET", self.path) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", self.path) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) class RoomMemberStateTestCase(RoomBase): """ Tests /rooms/$room_id/members/$user_id/state REST events. """ - user_id = b"@sid1:red" - - def setUp(self): + user_id = "@sid1:red" - super(RoomMemberStateTestCase, self).setUp() + def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) - def tearDown(self): - pass - def test_invalid_puts(self): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json - request, channel = make_request(b"PUT", path, '{}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '{}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"_name":"bob"}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '{"_name":"bo"}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"nao') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '{"nao') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request( - b"PUT", path, b'[{"_name":"bob"},{"_name":"jill"}]' + request, channel = self.make_request( + "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' ) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, 'text only') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, 'text only') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, '') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -646,9 +602,9 @@ class RoomMemberStateTestCase(RoomBase): Membership.JOIN, Membership.LEAVE, ) - request, channel = make_request(b"PUT", path, content.encode('ascii')) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content.encode('ascii')) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( @@ -658,13 +614,13 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN - request, channel = make_request(b"PUT", path, content.encode('ascii')) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content.encode('ascii')) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"GET", path, None) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", path, None) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEquals(expected_response, channel.json_body) @@ -678,13 +634,13 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"GET", path, None) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", path, None) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self): @@ -699,13 +655,13 @@ class RoomMemberStateTestCase(RoomBase): Membership.INVITE, "Join us!", ) - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"GET", path, None) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("GET", path, None) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) self.assertEquals(json.loads(content), channel.json_body) @@ -714,60 +670,58 @@ class RoomMessagesTestCase(RoomBase): user_id = "@sid1:red" - def setUp(self): - super(RoomMessagesTestCase, self).setUp() - + def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json - request, channel = make_request(b"PUT", path, '{}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'{}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"_name":"bob"}') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'{"_name":"bo"}') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '{"nao') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'{"nao') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request( - b"PUT", path, '[{"_name":"bob"},{"_name":"jill"}]' + request, channel = self.make_request( + "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]' ) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, 'text only') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'text only') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) - request, channel = make_request(b"PUT", path, '') - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + request, channel = self.make_request("PUT", path, b'') + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) - content = '{"body":"test","msgtype":{"type":"a"}}' - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(400, int(channel.result["code"]), msg=channel.result["body"]) + content = b'{"body":"test","msgtype":{"type":"a"}}' + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(400, channel.code, msg=channel.result["body"]) # custom message types - content = '{"body":"test","msgtype":"test.custom.text"}' - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + content = b'{"body":"test","msgtype":"test.custom.text"}' + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) - content = '{"body":"test2","msgtype":"m.text"}' - request, channel = make_request(b"PUT", path, content) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + content = b'{"body":"test2","msgtype":"m.text"}' + request, channel = self.make_request("PUT", path, content) + self.render(request) + self.assertEquals(200, channel.code, msg=channel.result["body"]) class RoomInitialSyncTestCase(RoomBase): @@ -775,16 +729,16 @@ class RoomInitialSyncTestCase(RoomBase): user_id = "@sid1:red" - def setUp(self): - super(RoomInitialSyncTestCase, self).setUp() - + def prepare(self, reactor, clock, hs): # create the room self.room_id = self.helper.create_room_as(self.user_id) def test_initial_sync(self): - request, channel = make_request(b"GET", "/rooms/%s/initialSync" % self.room_id) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + request, channel = self.make_request( + "GET", "/rooms/%s/initialSync" % self.room_id + ) + self.render(request) + self.assertEquals(200, channel.code) self.assertEquals(self.room_id, channel.json_body["room_id"]) self.assertEquals("join", channel.json_body["membership"]) @@ -819,17 +773,16 @@ class RoomMessageListTestCase(RoomBase): user_id = "@sid1:red" - def setUp(self): - super(RoomMessageListTestCase, self).setUp() + def prepare(self, reactor, clock, hs): self.room_id = self.helper.create_room_as(self.user_id) def test_topo_token_is_accepted(self): token = "t1-0_0_0_0_0_0_0_0_0" - request, channel = make_request( - b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + request, channel = self.make_request( + "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) self.assertEquals(token, channel.json_body['start']) self.assertTrue("chunk" in channel.json_body) @@ -837,12 +790,116 @@ class RoomMessageListTestCase(RoomBase): def test_stream_token_is_accepted_for_fwd_pagianation(self): token = "s0_0_0_0_0_0_0_0_0" - request, channel = make_request( - b"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) + request, channel = self.make_request( + "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - render(request, self.resource, self.clock) - self.assertEquals(200, int(channel.result["code"])) + self.render(request) + self.assertEquals(200, channel.code) self.assertTrue("start" in channel.json_body) self.assertEquals(token, channel.json_body['start']) self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) + + +class RoomSearchTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def prepare(self, reactor, clock, hs): + + # Register the user who does the searching + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + # Register the user who sends the message + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + # Create a room + self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) + + # Invite the other person + self.helper.invite( + room=self.room, + src=self.user_id, + tok=self.access_token, + targ=self.other_user_id, + ) + + # The other user joins + self.helper.join( + room=self.room, user=self.other_user_id, tok=self.other_access_token + ) + + def test_finds_message(self): + """ + The search functionality will search for content in messages if asked to + do so. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + request, channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": {"keys": ["content.body"], "search_term": "Hi"} + } + }, + ) + self.render(request) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # No context was requested, so we should get none. + self.assertEqual(results["results"][0]["context"], {}) + + def test_include_context(self): + """ + When event_context includes include_profile, profile information will be + included in the search response. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + request, channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": { + "keys": ["content.body"], + "search_term": "Hi", + "event_context": {"include_profile": True}, + } + } + }, + ) + self.render(request) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # We should get context info, like the two users, and the display names. + context = results["results"][0]["context"] + self.assertEqual(len(context["profile_info"].keys()), 2) + self.assertEqual( + context["profile_info"][self.other_user_id]["displayname"], "otheruser" + ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 40dc4ea256..530dc8ba6d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -240,7 +240,6 @@ class RestHelper(object): self.assertEquals(200, code) defer.returnValue(response) - @defer.inlineCallbacks def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -248,9 +247,16 @@ class RestHelper(object): body = "body_text_here" path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id) - content = '{"msgtype":"m.text","body":"%s"}' % body + content = {"msgtype": "m.text", "body": body} if tok: path = path + "?access_token=%s" % tok - (code, response) = yield self.mock_resource.trigger("PUT", path, content) - self.assertEquals(expect_code, code, msg=str(response)) + request, channel = make_request("PUT", path, json.dumps(content).encode('utf8')) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 560b1fba96..4c30c5f258 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -62,12 +62,6 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( set( - [ - "next_batch", - "rooms", - "account_data", - "to_device", - "device_lists", - ] + ["next_batch", "rooms", "account_data", "to_device", "device_lists"] ).issubset(set(channel.json_body.keys())) ) diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/scripts/__init__.py diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py new file mode 100644 index 0000000000..6f56893f5e --- /dev/null +++ b/tests/scripts/test_new_matrix_user.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from synapse._scripts.register_new_matrix_user import request_registration + +from tests.unittest import TestCase + + +class RegisterTestCase(TestCase): + def test_success(self): + """ + The script will fetch a nonce, and then generate a MAC with it, and then + post that MAC. + """ + + def get(url, verify=None): + r = Mock() + r.status_code = 200 + r.json = lambda: {"nonce": "a"} + return r + + def post(url, json=None, verify=None): + # Make sure we are sent the correct info + self.assertEqual(json["username"], "user") + self.assertEqual(json["password"], "pass") + self.assertEqual(json["nonce"], "a") + # We want a 40-char hex MAC + self.assertEqual(len(json["mac"]), 40) + + r = Mock() + r.status_code = 200 + return r + + requests = Mock() + requests.get = get + requests.post = post + + # The fake stdout will be written here + out = [] + err_code = [] + + request_registration( + "user", + "pass", + "matrix.org", + "shared", + admin=False, + requests=requests, + _print=out.append, + exit=err_code.append, + ) + + # We should get the success message making sure everything is OK. + self.assertIn("Success!", out) + + # sys.exit shouldn't have been called. + self.assertEqual(err_code, []) + + def test_failure_nonce(self): + """ + If the script fails to fetch a nonce, it throws an error and quits. + """ + + def get(url, verify=None): + r = Mock() + r.status_code = 404 + r.reason = "Not Found" + r.json = lambda: {"not": "error"} + return r + + requests = Mock() + requests.get = get + + # The fake stdout will be written here + out = [] + err_code = [] + + request_registration( + "user", + "pass", + "matrix.org", + "shared", + admin=False, + requests=requests, + _print=out.append, + exit=err_code.append, + ) + + # Exit was called + self.assertEqual(err_code, [1]) + + # We got an error message + self.assertIn("ERROR! Received 404 Not Found", out) + self.assertNotIn("Success!", out) + + def test_failure_post(self): + """ + The script will fetch a nonce, and then if the final POST fails, will + report an error and quit. + """ + + def get(url, verify=None): + r = Mock() + r.status_code = 200 + r.json = lambda: {"nonce": "a"} + return r + + def post(url, json=None, verify=None): + # Make sure we are sent the correct info + self.assertEqual(json["username"], "user") + self.assertEqual(json["password"], "pass") + self.assertEqual(json["nonce"], "a") + # We want a 40-char hex MAC + self.assertEqual(len(json["mac"]), 40) + + r = Mock() + # Then 500 because we're jerks + r.status_code = 500 + r.reason = "Broken" + return r + + requests = Mock() + requests.get = get + requests.post = post + + # The fake stdout will be written here + out = [] + err_code = [] + + request_registration( + "user", + "pass", + "matrix.org", + "shared", + admin=False, + requests=requests, + _print=out.append, + exit=err_code.append, + ) + + # Exit was called + self.assertEqual(err_code, [1]) + + # We got an error message + self.assertIn("ERROR! Received 500 Broken", out) + self.assertNotIn("Success!", out) diff --git a/tests/server.py b/tests/server.py index c63b2c3100..7bee58dff1 100644 --- a/tests/server.py +++ b/tests/server.py @@ -4,9 +4,14 @@ from io import BytesIO from six import text_type import attr +from zope.interface import implementer -from twisted.internet import threads +from twisted.internet import address, threads, udp +from twisted.internet._resolver import HostResolution +from twisted.internet.address import IPv4Address from twisted.internet.defer import Deferred +from twisted.internet.error import DNSLookupError +from twisted.internet.interfaces import IReactorPluggableNameResolver from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactorClock @@ -63,7 +68,9 @@ class FakeChannel(object): self.result["done"] = True def getPeer(self): - return None + # We give an address so that getClientIP returns a non null entry, + # causing us to record the MAU + return address.IPv4Address("TCP", "127.0.0.1", 3423) def getHost(self): return None @@ -91,7 +98,7 @@ class FakeSite: return FakeLogger() -def make_request(method, path, content=b""): +def make_request(method, path, content=b"", access_token=None, request=SynapseRequest): """ Make a web request using the given method and path, feed it the content, and return the Request and the Channel underneath. @@ -113,9 +120,16 @@ def make_request(method, path, content=b""): site = FakeSite() channel = FakeChannel() - req = SynapseRequest(site, channel) + req = request(site, channel) req.process = lambda: b"" req.content = BytesIO(content) + + if access_token: + req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token) + + if content: + req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") + req.requestReceived(method, path, b"1.1") return req, channel @@ -147,11 +161,46 @@ def render(request, resource, clock): wait_until_result(clock, request) +@implementer(IReactorPluggableNameResolver) class ThreadedMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports callFromThread. """ + def __init__(self): + self._udp = [] + self.lookups = {} + + class Resolver(object): + def resolveHostName( + _self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics='TCP', + ): + + resolution = HostResolution(hostName) + resolutionReceiver.resolutionBegan(resolution) + if hostName not in self.lookups: + raise DNSLookupError("OH NO") + + resolutionReceiver.addressResolved( + IPv4Address('TCP', self.lookups[hostName], portNumber) + ) + resolutionReceiver.resolutionComplete() + return resolution + + self.nameResolver = Resolver() + super(ThreadedMemoryReactorClock, self).__init__() + + def listenUDP(self, port, protocol, interface='', maxPacketSize=8196): + p = udp.Port(port, protocol, interface, maxPacketSize, self) + p.startListening() + self._udp.append(p) + return p + def callFromThread(self, callback, *args, **kwargs): """ Make the callback fire in the next reactor iteration. @@ -225,6 +274,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): clock.threadpool = ThreadPool() pool.threadpool = ThreadPool() + pool.running = True return d @@ -232,3 +282,84 @@ def get_clock(): clock = ThreadedMemoryReactorClock() hs_clock = Clock(clock) return (clock, hs_clock) + + +@attr.s +class FakeTransport(object): + """ + A twisted.internet.interfaces.ITransport implementation which sends all its data + straight into an IProtocol object: it exists to connect two IProtocols together. + + To use it, instantiate it with the receiving IProtocol, and then pass it to the + sending IProtocol's makeConnection method: + + server = HTTPChannel() + client.makeConnection(FakeTransport(server, self.reactor)) + + If you want bidirectional communication, you'll need two instances. + """ + + other = attr.ib() + """The Protocol object which will receive any data written to this transport. + + :type: twisted.internet.interfaces.IProtocol + """ + + _reactor = attr.ib() + """Test reactor + + :type: twisted.internet.interfaces.IReactorTime + """ + + disconnecting = False + buffer = attr.ib(default=b'') + producer = attr.ib(default=None) + + def getPeer(self): + return None + + def getHost(self): + return None + + def loseConnection(self): + self.disconnecting = True + + def abortConnection(self): + self.disconnecting = True + + def pauseProducing(self): + self.producer.pauseProducing() + + def unregisterProducer(self): + if not self.producer: + return + + self.producer = None + + def registerProducer(self, producer, streaming): + self.producer = producer + self.producerStreaming = streaming + + def _produce(): + d = self.producer.resumeProducing() + d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) + + if not streaming: + self._reactor.callLater(0.0, _produce) + + def write(self, byt): + self.buffer = self.buffer + byt + + def _write(): + if getattr(self.other, "transport") is not None: + self.other.dataReceived(self.buffer) + self.buffer = b"" + return + + self._reactor.callLater(0.0, _write) + + _write() + + def writeSequence(self, seq): + for x in seq: + self.write(x) diff --git a/tests/server_notices/__init__.py b/tests/server_notices/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/server_notices/__init__.py diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py new file mode 100644 index 0000000000..95badc985e --- /dev/null +++ b/tests/server_notices/test_consent.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v2_alpha import sync + +from tests import unittest + + +class ConsentNoticesTests(unittest.HomeserverTestCase): + + servlets = [ + sync.register_servlets, + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + self.consent_notice_message = "consent %(consent_uri)s" + config = self.default_config() + config.user_consent_version = "1" + config.user_consent_server_notice_content = { + "msgtype": "m.text", + "body": self.consent_notice_message, + } + config.public_baseurl = "https://example.com/" + config.form_secret = "123abc" + + config.server_notices_mxid = "@notices:test" + config.server_notices_mxid_display_name = "test display name" + config.server_notices_mxid_avatar_url = None + config.server_notices_room_name = "Server Notices" + + hs = self.setup_test_homeserver(config=config) + + return hs + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("bob", "abc123") + self.access_token = self.login("bob", "abc123") + + def test_get_sync_message(self): + """ + When user consent server notices are enabled, a sync will cause a notice + to fire (in a room which the user is invited to). The notice contains + the notice URL + an authentication code. + """ + # Initial sync, to get the user consent room invite + request, channel = self.make_request( + "GET", "/_matrix/client/r0/sync", access_token=self.access_token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # Get the Room ID to join + room_id = list(channel.json_body["rooms"]["invite"].keys())[0] + + # Join the room + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/" + room_id + "/join", + access_token=self.access_token, + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # Sync again, to get the message in the room + request, channel = self.make_request( + "GET", "/_matrix/client/r0/sync", access_token=self.access_token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # Get the message + room = channel.json_body["rooms"]["join"][room_id] + messages = [ + x for x in room["timeline"]["events"] if x["type"] == "m.room.message" + ] + + # One message, with the consent URL + self.assertEqual(len(messages), 1) + self.assertTrue( + messages[0]["content"]["body"].startswith( + "consent https://example.com/_matrix/consent" + ) + ) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py new file mode 100644 index 0000000000..4701eedd45 --- /dev/null +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -0,0 +1,207 @@ +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, ServerNoticeMsgType +from synapse.api.errors import ResourceLimitError +from synapse.handlers.auth import AuthHandler +from synapse.server_notices.resource_limits_server_notices import ( + ResourceLimitsServerNotices, +) + +from tests import unittest +from tests.utils import setup_test_homeserver + + +class AuthHandlers(object): + def __init__(self, hs): + self.auth_handler = AuthHandler(hs) + + +class TestResourceLimitsServerNotices(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) + self.hs.handlers = AuthHandlers(self.hs) + self.auth_handler = self.hs.handlers.auth_handler + self.server_notices_sender = self.hs.get_server_notices_sender() + + # relying on [1] is far from ideal, but the only case where + # ResourceLimitsServerNotices class needs to be isolated is this test, + # general code should never have a reason to do so ... + self._rlsn = self.server_notices_sender._server_notices[1] + if not isinstance(self._rlsn, ResourceLimitsServerNotices): + raise Exception("Failed to find reference to ResourceLimitsServerNotices") + + self._rlsn._store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(1000) + ) + self._send_notice = self._rlsn._server_notices_manager.send_notice + self._rlsn._server_notices_manager.send_notice = Mock() + self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None)) + self._rlsn._store.get_events = Mock(return_value=defer.succeed({})) + + self._send_notice = self._rlsn._server_notices_manager.send_notice + + self.hs.config.limit_usage_by_mau = True + self.user_id = "@user_id:test" + + # self.server_notices_mxid = "@server:test" + # self.server_notices_mxid_display_name = None + # self.server_notices_mxid_avatar_url = None + # self.server_notices_room_name = "Server Notices" + + self._rlsn._server_notices_manager.get_notice_room_for_user = Mock( + returnValue="" + ) + self._rlsn._store.add_tag_to_room = Mock() + self._rlsn._store.get_tags_for_room = Mock(return_value={}) + self.hs.config.admin_contact = "mailto:user@test.com" + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_flag_off(self): + """Tests cases where the flags indicate nothing to do""" + # test hs disabled case + self.hs.config.hs_disabled = True + + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + self._send_notice.assert_not_called() + # Test when mau limiting disabled + self.hs.config.hs_disabled = False + self.hs.limit_usage_by_mau = False + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + self._send_notice.assert_not_called() + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): + """Test when user has blocked notice, but should have it removed""" + + self._rlsn._auth.check_auth_blocking = Mock() + mock_event = Mock( + type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} + ) + self._rlsn._store.get_events = Mock( + return_value=defer.succeed({"123": mock_event}) + ) + + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + # Would be better to check the content, but once == remove blocking event + self._send_notice.assert_called_once() + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): + """Test when user has blocked notice, but notice ought to be there (NOOP)""" + self._rlsn._auth.check_auth_blocking = Mock( + side_effect=ResourceLimitError(403, 'foo') + ) + + mock_event = Mock( + type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} + ) + self._rlsn._store.get_events = Mock( + return_value=defer.succeed({"123": mock_event}) + ) + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + self._send_notice.assert_not_called() + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_add_blocked_notice(self): + """Test when user does not have blocked notice, but should have one""" + + self._rlsn._auth.check_auth_blocking = Mock( + side_effect=ResourceLimitError(403, 'foo') + ) + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + # Would be better to check contents, but 2 calls == set blocking event + self.assertTrue(self._send_notice.call_count == 2) + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): + """Test when user does not have blocked notice, nor should they (NOOP)""" + + self._rlsn._auth.check_auth_blocking = Mock() + + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + self._send_notice.assert_not_called() + + @defer.inlineCallbacks + def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): + + """Test when user is not part of the MAU cohort - this should not ever + happen - but ... + """ + + self._rlsn._auth.check_auth_blocking = Mock() + self._rlsn._store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(None) + ) + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + self._send_notice.assert_not_called() + + +class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(self.addCleanup) + self.store = self.hs.get_datastore() + self.server_notices_sender = self.hs.get_server_notices_sender() + self.server_notices_manager = self.hs.get_server_notices_manager() + self.event_source = self.hs.get_event_sources() + + # relying on [1] is far from ideal, but the only case where + # ResourceLimitsServerNotices class needs to be isolated is this test, + # general code should never have a reason to do so ... + self._rlsn = self.server_notices_sender._server_notices[1] + if not isinstance(self._rlsn, ResourceLimitsServerNotices): + raise Exception("Failed to find reference to ResourceLimitsServerNotices") + + self.hs.config.limit_usage_by_mau = True + self.hs.config.hs_disabled = False + self.hs.config.max_mau_value = 5 + self.hs.config.server_notices_mxid = "@server:test" + self.hs.config.server_notices_mxid_display_name = None + self.hs.config.server_notices_mxid_avatar_url = None + self.hs.config.server_notices_room_name = "Test Server Notice Room" + + self.user_id = "@user_id:test" + + self.hs.config.admin_contact = "mailto:user@test.com" + + @defer.inlineCallbacks + def test_server_notice_only_sent_once(self): + self.store.get_monthly_active_count = Mock(return_value=1000) + + self.store.user_last_seen_monthly_active = Mock(return_value=1000) + + # Call the function multiple times to ensure we only send the notice once + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) + + # Now lets get the last load of messages in the service notice room and + # check that there is only one server notice + room_id = yield self.server_notices_manager.get_notice_room_for_user( + self.user_id + ) + + token = yield self.event_source.get_current_token() + events, _ = yield self.store.get_recent_events_for_room( + room_id, limit=100, end_token=token.room_key + ) + + count = 0 + for event in events: + if event.type != EventTypes.Message: + continue + if event.content.get("msgtype") != ServerNoticeMsgType: + continue + + count += 1 + + self.assertEqual(count, 1) diff --git a/tests/state/__init__.py b/tests/state/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/state/__init__.py diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py new file mode 100644 index 0000000000..efd85ebe6c --- /dev/null +++ b/tests/state/test_v2.py @@ -0,0 +1,663 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +from six.moves import zip + +import attr + +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.event_auth import auth_types_for_event +from synapse.events import FrozenEvent +from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store +from synapse.types import EventID + +from tests import unittest + +ALICE = "@alice:example.com" +BOB = "@bob:example.com" +CHARLIE = "@charlie:example.com" +EVELYN = "@evelyn:example.com" +ZARA = "@zara:example.com" + +ROOM_ID = "!test:example.com" + +MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN} +MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN} + + +ORIGIN_SERVER_TS = 0 + + +class FakeEvent(object): + """A fake event we use as a convenience. + + NOTE: Again as a convenience we use "node_ids" rather than event_ids to + refer to events. The event_id has node_id as localpart and example.com + as domain. + """ + def __init__(self, id, sender, type, state_key, content): + self.node_id = id + self.event_id = EventID(id, "example.com").to_string() + self.sender = sender + self.type = type + self.state_key = state_key + self.content = content + + def to_event(self, auth_events, prev_events): + """Given the auth_events and prev_events, convert to a Frozen Event + + Args: + auth_events (list[str]): list of event_ids + prev_events (list[str]): list of event_ids + + Returns: + FrozenEvent + """ + global ORIGIN_SERVER_TS + + ts = ORIGIN_SERVER_TS + ORIGIN_SERVER_TS = ORIGIN_SERVER_TS + 1 + + event_dict = { + "auth_events": [(a, {}) for a in auth_events], + "prev_events": [(p, {}) for p in prev_events], + "event_id": self.node_id, + "sender": self.sender, + "type": self.type, + "content": self.content, + "origin_server_ts": ts, + "room_id": ROOM_ID, + } + + if self.state_key is not None: + event_dict["state_key"] = self.state_key + + return FrozenEvent(event_dict) + + +# All graphs start with this set of events +INITIAL_EVENTS = [ + FakeEvent( + id="CREATE", + sender=ALICE, + type=EventTypes.Create, + state_key="", + content={"creator": ALICE}, + ), + FakeEvent( + id="IMA", + sender=ALICE, + type=EventTypes.Member, + state_key=ALICE, + content=MEMBERSHIP_CONTENT_JOIN, + ), + FakeEvent( + id="IPOWER", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100}}, + ), + FakeEvent( + id="IJR", + sender=ALICE, + type=EventTypes.JoinRules, + state_key="", + content={"join_rule": JoinRules.PUBLIC}, + ), + FakeEvent( + id="IMB", + sender=BOB, + type=EventTypes.Member, + state_key=BOB, + content=MEMBERSHIP_CONTENT_JOIN, + ), + FakeEvent( + id="IMC", + sender=CHARLIE, + type=EventTypes.Member, + state_key=CHARLIE, + content=MEMBERSHIP_CONTENT_JOIN, + ), + FakeEvent( + id="IMZ", + sender=ZARA, + type=EventTypes.Member, + state_key=ZARA, + content=MEMBERSHIP_CONTENT_JOIN, + ), + FakeEvent( + id="START", + sender=ZARA, + type=EventTypes.Message, + state_key=None, + content={}, + ), + FakeEvent( + id="END", + sender=ZARA, + type=EventTypes.Message, + state_key=None, + content={}, + ), +] + +INITIAL_EDGES = [ + "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE", +] + + +class StateTestCase(unittest.TestCase): + def test_ban_vs_pl(self): + events = [ + FakeEvent( + id="PA", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={ + "users": { + ALICE: 100, + BOB: 50, + } + }, + ), + FakeEvent( + id="MA", + sender=ALICE, + type=EventTypes.Member, + state_key=ALICE, + content={"membership": Membership.JOIN}, + ), + FakeEvent( + id="MB", + sender=ALICE, + type=EventTypes.Member, + state_key=BOB, + content={"membership": Membership.BAN}, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + ] + + edges = [ + ["END", "MB", "MA", "PA", "START"], + ["END", "PB", "PA"], + ] + + expected_state_ids = ["PA", "MA", "MB"] + + self.do_check(events, edges, expected_state_ids) + + def test_join_rule_evasion(self): + events = [ + FakeEvent( + id="JR", + sender=ALICE, + type=EventTypes.JoinRules, + state_key="", + content={"join_rules": JoinRules.PRIVATE}, + ), + FakeEvent( + id="ME", + sender=EVELYN, + type=EventTypes.Member, + state_key=EVELYN, + content={"membership": Membership.JOIN}, + ), + ] + + edges = [ + ["END", "JR", "START"], + ["END", "ME", "START"], + ] + + expected_state_ids = ["JR"] + + self.do_check(events, edges, expected_state_ids) + + def test_offtopic_pl(self): + events = [ + FakeEvent( + id="PA", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={ + "users": { + ALICE: 100, + BOB: 50, + } + }, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + CHARLIE: 50, + }, + }, + ), + FakeEvent( + id="PC", + sender=CHARLIE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + CHARLIE: 0, + }, + }, + ), + ] + + edges = [ + ["END", "PC", "PB", "PA", "START"], + ["END", "PA"], + ] + + expected_state_ids = ["PC"] + + self.do_check(events, edges, expected_state_ids) + + def test_topic_basic(self): + events = [ + FakeEvent( + id="T1", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="PA1", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + FakeEvent( + id="T2", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="PA2", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 0, + }, + }, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + FakeEvent( + id="T3", + sender=BOB, + type=EventTypes.Topic, + state_key="", + content={}, + ), + ] + + edges = [ + ["END", "PA2", "T2", "PA1", "T1", "START"], + ["END", "T3", "PB", "PA1"], + ] + + expected_state_ids = ["PA2", "T2"] + + self.do_check(events, edges, expected_state_ids) + + def test_topic_reset(self): + events = [ + FakeEvent( + id="T1", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="PA", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + FakeEvent( + id="T2", + sender=BOB, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="MB", + sender=ALICE, + type=EventTypes.Member, + state_key=BOB, + content={"membership": Membership.BAN}, + ), + ] + + edges = [ + ["END", "MB", "T2", "PA", "T1", "START"], + ["END", "T1"], + ] + + expected_state_ids = ["T1", "MB", "PA"] + + self.do_check(events, edges, expected_state_ids) + + def test_topic(self): + events = [ + FakeEvent( + id="T1", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="PA1", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + FakeEvent( + id="T2", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="PA2", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 0, + }, + }, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key='', + content={ + "users": { + ALICE: 100, + BOB: 50, + }, + }, + ), + FakeEvent( + id="T3", + sender=BOB, + type=EventTypes.Topic, + state_key="", + content={}, + ), + FakeEvent( + id="MZ1", + sender=ZARA, + type=EventTypes.Message, + state_key=None, + content={}, + ), + FakeEvent( + id="T4", + sender=ALICE, + type=EventTypes.Topic, + state_key="", + content={}, + ), + ] + + edges = [ + ["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"], + ["END", "MZ1", "T3", "PB", "PA1"], + ] + + expected_state_ids = ["T4", "PA2"] + + self.do_check(events, edges, expected_state_ids) + + def do_check(self, events, edges, expected_state_ids): + """Take a list of events and edges and calculate the state of the + graph at END, and asserts it matches `expected_state_ids` + + Args: + events (list[FakeEvent]) + edges (list[list[str]]): A list of chains of event edges, e.g. + `[[A, B, C]]` are edges A->B and B->C. + expected_state_ids (list[str]): The expected state at END, (excluding + the keys that haven't changed since START). + """ + # We want to sort the events into topological order for processing. + graph = {} + + # node_id -> FakeEvent + fake_event_map = {} + + for ev in itertools.chain(INITIAL_EVENTS, events): + graph[ev.node_id] = set() + fake_event_map[ev.node_id] = ev + + for a, b in pairwise(INITIAL_EDGES): + graph[a].add(b) + + for edge_list in edges: + for a, b in pairwise(edge_list): + graph[a].add(b) + + # event_id -> FrozenEvent + event_map = {} + # node_id -> state + state_at_event = {} + + # We copy the map as the sort consumes the graph + graph_copy = {k: set(v) for k, v in graph.items()} + + for node_id in lexicographical_topological_sort(graph_copy, key=lambda e: e): + fake_event = fake_event_map[node_id] + event_id = fake_event.event_id + + prev_events = list(graph[node_id]) + + if len(prev_events) == 0: + state_before = {} + elif len(prev_events) == 1: + state_before = dict(state_at_event[prev_events[0]]) + else: + state_d = resolve_events_with_store( + [state_at_event[n] for n in prev_events], + event_map=event_map, + state_res_store=TestStateResolutionStore(event_map), + ) + + self.assertTrue(state_d.called) + state_before = state_d.result + + state_after = dict(state_before) + if fake_event.state_key is not None: + state_after[(fake_event.type, fake_event.state_key)] = event_id + + auth_types = set(auth_types_for_event(fake_event)) + + auth_events = [] + for key in auth_types: + if key in state_before: + auth_events.append(state_before[key]) + + event = fake_event.to_event(auth_events, prev_events) + + state_at_event[node_id] = state_after + event_map[event_id] = event + + expected_state = {} + for node_id in expected_state_ids: + # expected_state_ids are node IDs rather than event IDs, + # so we have to convert + event_id = EventID(node_id, "example.com").to_string() + event = event_map[event_id] + + key = (event.type, event.state_key) + + expected_state[key] = event_id + + start_state = state_at_event["START"] + end_state = { + key: value + for key, value in state_at_event["END"].items() + if key in expected_state or start_state.get(key) != value + } + + self.assertEqual(expected_state, end_state) + + +class LexicographicalTestCase(unittest.TestCase): + def test_simple(self): + graph = { + "l": {"o"}, + "m": {"n", "o"}, + "n": {"o"}, + "o": set(), + "p": {"o"}, + } + + res = list(lexicographical_topological_sort(graph, key=lambda x: x)) + + self.assertEqual(["o", "l", "n", "m", "p"], res) + + +def pairwise(iterable): + "s -> (s0,s1), (s1,s2), (s2, s3), ..." + a, b = itertools.tee(iterable) + next(b, None) + return zip(a, b) + + +@attr.s +class TestStateResolutionStore(object): + event_map = attr.ib() + + def get_events(self, event_ids, allow_rejected=False): + """Get events from the database + + Args: + event_ids (list): The event_ids of the events to fetch + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + """ + + return { + eid: self.event_map[eid] + for eid in event_ids + if eid in self.event_map + } + + def get_auth_chain(self, event_ids): + """Gets the full auth chain for a set of events (including rejected + events). + + Includes the given event IDs in the result. + + Note that: + 1. All events must be state events. + 2. For v1 rooms this may not have the full auth chain in the + presence of rejected events + + Args: + event_ids (list): The event IDs of the events to fetch the auth + chain for. Must be state events. + + Returns: + Deferred[list[str]]: List of event IDs of the auth chain. + """ + + # Simple DFS for auth chain + result = set() + stack = list(event_ids) + while stack: + event_id = stack.pop() + if event_id in result: + continue + + result.add(event_id) + + event = self.event_map[event_id] + for aid, _ in event.auth_events: + stack.append(aid) + + return list(result) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index c893990454..3f0083831b 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.as_yaml_files = [] - config = Mock( - app_service_config_files=self.as_yaml_files, - event_cache_size=1, - password_providers=[], - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + hs.config.app_service_config_files = self.as_yaml_files + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + self.as_token = "token1" self.as_url = "some_url" self.as_id = "as1" @@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(None, hs) + self.store = ApplicationServiceStore(hs.get_db_conn(), hs) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): def setUp(self): self.as_yaml_files = [] - config = Mock( - app_service_config_files=self.as_yaml_files, - event_cache_size=1, - password_providers=[], - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + + hs.config.app_service_config_files = self.as_yaml_files + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + self.db_pool = hs.get_db_pool() + self.engine = hs.database_engine self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, @@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - self.store = TestTransactionStore(None, hs) + self.store = TestTransactionStore(hs.get_db_conn(), hs) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files.append(as_token) def _set_state(self, id, state, txn=None): - return self.db_pool.runQuery( - "INSERT INTO application_services_state(as_id, state, last_txn) " - "VALUES(?,?,?)", + return self.db_pool.runOperation( + self.engine.convert_param_style( + "INSERT INTO application_services_state(as_id, state, last_txn) " + "VALUES(?,?,?)" + ), (id, state, txn), ) def _insert_txn(self, as_id, txn_id, events): - return self.db_pool.runQuery( - "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " - "VALUES(?,?,?)", + return self.db_pool.runOperation( + self.engine.convert_param_style( + "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " + "VALUES(?,?,?)" + ), (as_id, txn_id, json.dumps([e.event_id for e in events])), ) def _set_last_txn(self, as_id, txn_id): - return self.db_pool.runQuery( - "INSERT INTO application_services_state(as_id, last_txn, state) " - "VALUES(?,?,?)", + return self.db_pool.runOperation( + self.engine.convert_param_style( + "INSERT INTO application_services_state(as_id, last_txn, state) " + "VALUES(?,?,?)" + ), (as_id, txn_id, ApplicationServiceState.UP), ) @defer.inlineCallbacks def test_get_appservice_state_none(self): - service = Mock(id=999) + service = Mock(id="999") state = yield self.store.get_appservice_state(service) self.assertEquals(None, state) @@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): service = Mock(id=self.as_list[1]["id"]) yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) rows = yield self.db_pool.runQuery( - "SELECT as_id FROM application_services_state WHERE state=?", + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), (ApplicationServiceState.DOWN,), ) self.assertEquals(service.id, rows[0][0]) @@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) yield self.store.set_appservice_state(service, ApplicationServiceState.UP) rows = yield self.db_pool.runQuery( - "SELECT as_id FROM application_services_state WHERE state=?", + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), (ApplicationServiceState.UP,), ) self.assertEquals(service.id, rows[0][0]) @@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) res = yield self.db_pool.runQuery( - "SELECT last_txn FROM application_services_state WHERE as_id=?", + self.engine.convert_param_style( + "SELECT last_txn FROM application_services_state WHERE as_id=?" + ), (service.id,), ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) res = yield self.db_pool.runQuery( - "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), ) self.assertEquals(0, len(res)) @@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) res = yield self.db_pool.runQuery( - "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?", + self.engine.convert_param_style( + "SELECT last_txn, state FROM application_services_state WHERE as_id=?" + ), (service.id,), ) self.assertEquals(1, len(res)) @@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(ApplicationServiceState.UP, res[0][1]) res = yield self.db_pool.runQuery( - "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), ) self.assertEquals(0, len(res)) @@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(suffix="1") f2 = self._write_config(suffix="2") - config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - datastore=Mock(), - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - ApplicationServiceStore(None, hs) + hs.config.app_service_config_files = [f1, f2] + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + + ApplicationServiceStore(hs.get_db_conn(), hs) @defer.inlineCallbacks def test_duplicate_ids(self): f1 = self._write_config(id="id", suffix="1") f2 = self._write_config(id="id", suffix="2") - config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - datastore=Mock(), - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + hs.config.app_service_config_files = [f1, f2] + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(None, hs) + ApplicationServiceStore(hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) @@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f1 = self._write_config(as_token="as_token", suffix="1") f2 = self._write_config(as_token="as_token", suffix="2") - config = Mock( - app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[] - ) hs = yield setup_test_homeserver( - self.addCleanup, - config=config, - datastore=Mock(), - federation_sender=Mock(), - federation_client=Mock(), + self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) + hs.config.app_service_config_files = [f1, f2] + hs.config.event_cache_size = 1 + hs.config.password_providers = [] + with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(None, hs) + ApplicationServiceStore(hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 7cb5f0e4cf..829f47d2e8 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -20,11 +20,11 @@ from mock import Mock from twisted.internet import defer -from synapse.server import HomeServer from synapse.storage._base import SQLBaseStore from synapse.storage.engines import create_engine from tests import unittest +from tests.utils import TestHomeServer class SQLBaseStoreTestCase(unittest.TestCase): @@ -51,7 +51,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config.event_cache_size = 1 config.database_config = {"name": "sqlite3"} - hs = HomeServer( + hs = TestHomeServer( "test", db_pool=self.db_pool, config=config, diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c2e88bdbaf..4577e9422b 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,35 +13,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from mock import Mock from twisted.internet import defer -import tests.unittest -import tests.utils +from synapse.http.site import XForwardedForRequest +from synapse.rest.client.v1 import admin, login + +from tests import unittest -class ClientIpStoreTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super(ClientIpStoreTestCase, self).__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore - self.clock = None # type: tests.utils.MockClock +class ClientIpStoreTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs - @defer.inlineCallbacks - def setUp(self): - self.hs = yield tests.utils.setup_test_homeserver(self.addCleanup) + def prepare(self, hs, reactor, clock): self.store = self.hs.get_datastore() - self.clock = self.hs.get_clock() - @defer.inlineCallbacks def test_insert_new_client_ip(self): - self.clock.now = 12345678 + self.reactor.advance(12345678) + user_id = "@user:id" - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - result = yield self.store.get_last_client_ip_by_device(user_id, "device_id") + # Trigger the storage loop + self.reactor.advance(10) + + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, "device_id") + ) r = result[(user_id, "device_id")] self.assertDictContainsSubset( @@ -55,18 +62,18 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): r, ) - @defer.inlineCallbacks def test_disabled_monthly_active_user(self): self.hs.config.limit_usage_by_mau = False self.hs.config.max_mau_value = 50 user_id = "@user:server" - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - @defer.inlineCallbacks def test_adding_monthly_active_user_when_full(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 @@ -76,40 +83,119 @@ class ClientIpStoreTestCase(tests.unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(lots_of_users) ) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - @defer.inlineCallbacks def test_adding_monthly_active_user_when_space(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + # Trigger the saving loop + self.reactor.advance(10) + + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) - @defer.inlineCallbacks def test_updating_monthly_active_user_when_space(self): self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 50 user_id = "@user:server" + self.get_success( + self.store.register(user_id=user_id, token="123", password_hash=None) + ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" - ) - yield self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + # Trigger the saving loop + self.reactor.advance(10) + + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", "device_id" + ) ) - active = yield self.store.user_last_seen_monthly_active(user_id) + active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + + +class ClientIpAuthTestCase(unittest.HomeserverTestCase): + + servlets = [admin.register_servlets, login.register_servlets] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + return hs + + def prepare(self, hs, reactor, clock): + self.store = self.hs.get_datastore() + self.user_id = self.register_user("bob", "abc123", True) + + def test_request_with_xforwarded(self): + """ + The IP in X-Forwarded-For is entered into the client IPs table. + """ + self._runtest( + {b"X-Forwarded-For": b"127.9.0.1"}, + "127.9.0.1", + {"request": XForwardedForRequest}, + ) + + def test_request_from_getPeer(self): + """ + The IP returned by getPeer is entered into the client IPs table, if + there's no X-Forwarded-For header. + """ + self._runtest({}, "127.0.0.1", {}) + + def _runtest(self, headers, expected_ip, make_request_args): + device_id = "bleb" + + access_token = self.login("bob", "abc123", device_id=device_id) + + # Advance to a known time + self.reactor.advance(123456 - self.reactor.seconds()) + + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/admin/users/" + self.user_id, + access_token=access_token, + **make_request_args + ) + request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") + + # Add the optional headers + for h, v in headers.items(): + request.requestHeaders.addRawHeader(h, v) + self.render(request) + + # Advance so the save loop occurs + self.reactor.advance(100) + + result = self.get_success( + self.store.get_last_client_ip_by_device(self.user_id, device_id) + ) + r = result[(self.user_id, device_id)] + self.assertDictContainsSubset( + { + "user_id": self.user_id, + "device_id": device_id, + "ip": expected_ip, + "user_agent": "Mozzila pizza", + "last_seen": 123456100, + }, + r, + ) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index b4510c1c8d..4e128e1047 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -16,7 +16,6 @@ from twisted.internet import defer -from synapse.storage.directory import DirectoryStore from synapse.types import RoomAlias, RoomID from tests import unittest @@ -28,7 +27,7 @@ class DirectoryStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = DirectoryStore(None, hs) + self.store = hs.get_datastore() self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 2fdf34fdf6..0d4e74d637 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -37,10 +37,10 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ( "INSERT INTO events (" " room_id, event_id, type, depth, topological_ordering," - " content, processed, outlier) " - "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)" + " content, processed, outlier, stream_ordering) " + "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)" ), - (room_id, event_id, i, i, True, False), + (room_id, event_id, i, i, True, False, i), ) txn.execute( diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index f2ed866ae7..832e379a83 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -12,26 +12,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock from twisted.internet import defer -import tests.unittest -import tests.utils -from tests.utils import setup_test_homeserver +from tests.unittest import HomeserverTestCase FORTY_DAYS = 40 * 24 * 60 * 60 -class MonthlyActiveUsersTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs) +class MonthlyActiveUsersTestCase(HomeserverTestCase): + def make_homeserver(self, reactor, clock): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = self.hs.get_datastore() + hs = self.setup_test_homeserver() + self.store = hs.get_datastore() + hs.config.limit_usage_by_mau = True + hs.config.max_mau_value = 50 + # Advance the clock a bit + reactor.advance(FORTY_DAYS) + + return hs - @defer.inlineCallbacks def test_initialise_reserved_users(self): self.hs.config.max_mau_value = 5 user1 = "@user1:server" @@ -44,88 +45,178 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase): ] user_num = len(threepids) - yield self.store.register(user_id=user1, token="123", password_hash=None) - - yield self.store.register(user_id=user2, token="456", password_hash=None) + self.store.register(user_id=user1, token="123", password_hash=None) + self.store.register(user_id=user2, token="456", password_hash=None) + self.pump() now = int(self.hs.get_clock().time_msec()) - yield self.store.user_add_threepid(user1, "email", user1_email, now, now) - yield self.store.user_add_threepid(user2, "email", user2_email, now, now) - yield self.store.initialise_reserved_users(threepids) + self.store.user_add_threepid(user1, "email", user1_email, now, now) + self.store.user_add_threepid(user2, "email", user2_email, now, now) - active_count = yield self.store.get_monthly_active_count() + self.store.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) + self.pump() + + active_count = self.store.get_monthly_active_count() # Test total counts - self.assertEquals(active_count, user_num) + self.assertEquals(self.get_success(active_count), user_num) # Test user is marked as active - - timestamp = yield self.store.user_last_seen_monthly_active(user1) - self.assertTrue(timestamp) - timestamp = yield self.store.user_last_seen_monthly_active(user2) - self.assertTrue(timestamp) + timestamp = self.store.user_last_seen_monthly_active(user1) + self.assertTrue(self.get_success(timestamp)) + timestamp = self.store.user_last_seen_monthly_active(user2) + self.assertTrue(self.get_success(timestamp)) # Test that users are never removed from the db. self.hs.config.max_mau_value = 0 - self.hs.get_clock().advance_time(FORTY_DAYS) + self.reactor.advance(FORTY_DAYS) - yield self.store.reap_monthly_active_users() + self.store.reap_monthly_active_users() + self.pump() - active_count = yield self.store.get_monthly_active_count() - self.assertEquals(active_count, user_num) + active_count = self.store.get_monthly_active_count() + self.assertEquals(self.get_success(active_count), user_num) - # Test that regalar users are removed from the db + # Test that regular users are removed from the db ru_count = 2 - yield self.store.upsert_monthly_active_user("@ru1:server") - yield self.store.upsert_monthly_active_user("@ru2:server") - active_count = yield self.store.get_monthly_active_count() + self.store.upsert_monthly_active_user("@ru1:server") + self.store.upsert_monthly_active_user("@ru2:server") + self.pump() - self.assertEqual(active_count, user_num + ru_count) + active_count = self.store.get_monthly_active_count() + self.assertEqual(self.get_success(active_count), user_num + ru_count) self.hs.config.max_mau_value = user_num - yield self.store.reap_monthly_active_users() + self.store.reap_monthly_active_users() + self.pump() - active_count = yield self.store.get_monthly_active_count() - self.assertEquals(active_count, user_num) + active_count = self.store.get_monthly_active_count() + self.assertEquals(self.get_success(active_count), user_num) - @defer.inlineCallbacks def test_can_insert_and_count_mau(self): - count = yield self.store.get_monthly_active_count() - self.assertEqual(0, count) + count = self.store.get_monthly_active_count() + self.assertEqual(0, self.get_success(count)) - yield self.store.upsert_monthly_active_user("@user:server") - count = yield self.store.get_monthly_active_count() + self.store.upsert_monthly_active_user("@user:server") + self.pump() - self.assertEqual(1, count) + count = self.store.get_monthly_active_count() + self.assertEqual(1, self.get_success(count)) - @defer.inlineCallbacks def test_user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" - result = yield self.store.user_last_seen_monthly_active(user_id1) - self.assertFalse(result == 0) - yield self.store.upsert_monthly_active_user(user_id1) - yield self.store.upsert_monthly_active_user(user_id2) - result = yield self.store.user_last_seen_monthly_active(user_id1) - self.assertTrue(result > 0) - result = yield self.store.user_last_seen_monthly_active(user_id3) - self.assertFalse(result == 0) + result = self.store.user_last_seen_monthly_active(user_id1) + self.assertFalse(self.get_success(result) == 0) + + self.store.upsert_monthly_active_user(user_id1) + self.store.upsert_monthly_active_user(user_id2) + self.pump() + + result = self.store.user_last_seen_monthly_active(user_id1) + self.assertGreater(self.get_success(result), 0) + + result = self.store.user_last_seen_monthly_active(user_id3) + self.assertNotEqual(self.get_success(result), 0) - @defer.inlineCallbacks def test_reap_monthly_active_users(self): self.hs.config.max_mau_value = 5 initial_users = 10 for i in range(initial_users): - yield self.store.upsert_monthly_active_user("@user%d:server" % i) - count = yield self.store.get_monthly_active_count() - self.assertTrue(count, initial_users) - yield self.store.reap_monthly_active_users() - count = yield self.store.get_monthly_active_count() - self.assertEquals(count, initial_users - self.hs.config.max_mau_value) - - self.hs.get_clock().advance_time(FORTY_DAYS) - yield self.store.reap_monthly_active_users() - count = yield self.store.get_monthly_active_count() - self.assertEquals(count, 0) + self.store.upsert_monthly_active_user("@user%d:server" % i) + self.pump() + + count = self.store.get_monthly_active_count() + self.assertTrue(self.get_success(count), initial_users) + + self.store.reap_monthly_active_users() + self.pump() + count = self.store.get_monthly_active_count() + self.assertEquals( + self.get_success(count), initial_users - self.hs.config.max_mau_value + ) + + self.reactor.advance(FORTY_DAYS) + self.store.reap_monthly_active_users() + self.pump() + + count = self.store.get_monthly_active_count() + self.assertEquals(self.get_success(count), 0) + + def test_populate_monthly_users_is_guest(self): + # Test that guest users are not added to mau list + user_id = "user_id" + self.store.register( + user_id=user_id, token="123", password_hash=None, make_guest=True + ) + self.store.upsert_monthly_active_user = Mock() + self.store.populate_monthly_active_users(user_id) + self.pump() + self.store.upsert_monthly_active_user.assert_not_called() + + def test_populate_monthly_users_should_update(self): + self.store.upsert_monthly_active_user = Mock() + + self.store.is_trial_user = Mock( + return_value=defer.succeed(False) + ) + + self.store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed(None) + ) + self.store.populate_monthly_active_users('user_id') + self.pump() + self.store.upsert_monthly_active_user.assert_called_once() + + def test_populate_monthly_users_should_not_update(self): + self.store.upsert_monthly_active_user = Mock() + + self.store.is_trial_user = Mock( + return_value=defer.succeed(False) + ) + self.store.user_last_seen_monthly_active = Mock( + return_value=defer.succeed( + self.hs.get_clock().time_msec() + ) + ) + self.store.populate_monthly_active_users('user_id') + self.pump() + self.store.upsert_monthly_active_user.assert_not_called() + + def test_get_reserved_real_user_account(self): + # Test no reserved users, or reserved threepids + count = self.store.get_registered_reserved_users_count() + self.assertEquals(self.get_success(count), 0) + # Test reserved users but no registered users + + user1 = '@user1:example.com' + user2 = '@user2:example.com' + user1_email = 'user1@example.com' + user2_email = 'user2@example.com' + threepids = [ + {'medium': 'email', 'address': user1_email}, + {'medium': 'email', 'address': user2_email}, + ] + self.hs.config.mau_limits_reserved_threepids = threepids + self.store.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) + + self.pump() + count = self.store.get_registered_reserved_users_count() + self.assertEquals(self.get_success(count), 0) + + # Test reserved registed users + self.store.register(user_id=user1, token="123", password_hash=None) + self.store.register(user_id=user2, token="456", password_hash=None) + self.pump() + + now = int(self.hs.get_clock().time_msec()) + self.store.user_add_threepid(user1, "email", user1_email, now, now) + self.store.user_add_threepid(user2, "email", user2_email, now, now) + count = self.store.get_registered_reserved_users_count() + self.assertEquals(self.get_success(count), len(threepids)) diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index b5b58ff660..c7a63f39b9 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -16,19 +16,18 @@ from twisted.internet import defer -from synapse.storage.presence import PresenceStore from synapse.types import UserID from tests import unittest -from tests.utils import MockClock, setup_test_homeserver +from tests.utils import setup_test_homeserver class PresenceStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock()) + hs = yield setup_test_homeserver(self.addCleanup) - self.store = PresenceStore(None, hs) + self.store = hs.get_datastore() self.u_apple = UserID.from_string("@apple:test") self.u_banana = UserID.from_string("@banana:test") diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index dc3a2fd976..c125a0d797 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -28,7 +28,7 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = ProfileStore(None, hs) + self.store = ProfileStore(hs.get_db_conn(), hs) self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py new file mode 100644 index 0000000000..f671599cb8 --- /dev/null +++ b/tests/storage/test_purge.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.client.v1 import room + +from tests.unittest import HomeserverTestCase + + +class PurgeTests(HomeserverTestCase): + + user_id = "@red:server" + servlets = [room.register_servlets] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver("server", http_client=None) + return hs + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_purge(self): + """ + Purging a room will delete everything before the topological point. + """ + # Send four messages to the room + first = self.helper.send(self.room_id, body="test1") + second = self.helper.send(self.room_id, body="test2") + third = self.helper.send(self.room_id, body="test3") + last = self.helper.send(self.room_id, body="test4") + + storage = self.hs.get_datastore() + + # Get the topological token + event = storage.get_topological_token_for_event(last["event_id"]) + self.pump() + event = self.successResultOf(event) + + # Purge everything before this topological token + purge = storage.purge_history(self.room_id, event, True) + self.pump() + self.assertEqual(self.successResultOf(purge), None) + + # Try and get the events + get_first = storage.get_event(first["event_id"]) + get_second = storage.get_event(second["event_id"]) + get_third = storage.get_event(third["event_id"]) + get_last = storage.get_event(last["event_id"]) + self.pump() + + # 1-3 should fail and last will succeed, meaning that 1-3 are deleted + # and last is not. + self.failureResultOf(get_first) + self.failureResultOf(get_second) + self.failureResultOf(get_third) + self.successResultOf(get_last) + + def test_purge_wont_delete_extrems(self): + """ + Purging a room will delete everything before the topological point. + """ + # Send four messages to the room + first = self.helper.send(self.room_id, body="test1") + second = self.helper.send(self.room_id, body="test2") + third = self.helper.send(self.room_id, body="test3") + last = self.helper.send(self.room_id, body="test4") + + storage = self.hs.get_datastore() + + # Set the topological token higher than it should be + event = storage.get_topological_token_for_event(last["event_id"]) + self.pump() + event = self.successResultOf(event) + event = "t{}-{}".format( + *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) + ) + + # Purge everything before this topological token + purge = storage.purge_history(self.room_id, event, True) + self.pump() + f = self.failureResultOf(purge) + self.assertIn("greater than forward", f.value.args[0]) + + # Try and get the events + get_first = storage.get_event(first["event_id"]) + get_second = storage.get_event(second["event_id"]) + get_third = storage.get_event(third["event_id"]) + get_last = storage.get_event(last["event_id"]) + self.pump() + + # Nothing is deleted. + self.successResultOf(get_first) + self.successResultOf(get_second) + self.successResultOf(get_third) + self.successResultOf(get_last) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index c4e9fb72bf..02bf975fbf 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.types import RoomID, UserID from tests import unittest -from tests.utils import setup_test_homeserver +from tests.utils import create_room, setup_test_homeserver class RedactionTestCase(unittest.TestCase): @@ -41,6 +41,8 @@ class RedactionTestCase(unittest.TestCase): self.room1 = RoomID.from_string("!abc123:test") + yield create_room(hs, self.room1.to_string(), self.u_alice.to_string()) + self.depth = 1 @defer.inlineCallbacks diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 4eda122edc..3dfb7b903a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -46,6 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "consent_version": None, "consent_server_notice_sent": None, "appservice_id": None, + "creation_ts": 1000, }, (yield self.store.get_user_by_id(self.user_id)), ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index c83ef60062..978c66133d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.types import RoomID, UserID from tests import unittest -from tests.utils import setup_test_homeserver +from tests.utils import create_room, setup_test_homeserver class RoomMemberStoreTestCase(unittest.TestCase): @@ -45,6 +45,8 @@ class RoomMemberStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abc123:test") + yield create_room(hs, self.room.to_string(), self.u_alice.to_string()) + @defer.inlineCallbacks def inject_room_member(self, room, user, membership, replaces_state=None): builder = self.event_builder_factory.new( diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index ebfd969b36..086a39d834 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -18,6 +18,7 @@ import logging from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.storage.state import StateFilter from synapse.types import RoomID, UserID import tests.unittest @@ -75,6 +76,45 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(len(s1), len(s2)) @defer.inlineCallbacks + def test_get_state_groups_ids(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_map = list(state_group_map.values())[0] + self.assertDictEqual( + state_map, + { + (EventTypes.Create, ''): e1.event_id, + (EventTypes.Name, ''): e2.event_id, + }, + ) + + @defer.inlineCallbacks + def test_get_state_groups(self): + e1 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, '', {} + ) + e2 = yield self.inject_state_event( + self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} + ) + + state_group_map = yield self.store.get_state_groups( + self.room, [e2.event_id]) + self.assertEqual(len(state_group_map), 1) + state_list = list(state_group_map.values())[0] + + self.assertEqual( + {ev.event_id for ev in state_list}, + {e1.event_id, e2.event_id}, + ) + + @defer.inlineCallbacks def test_get_state_for_event(self): # this defaults to a linear DAG as each new injection defaults to whatever @@ -109,7 +149,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we get the full state as of the final event state = yield self.store.get_state_for_event( - e5.event_id, None, filtered_types=None + e5.event_id, ) self.assertIsNotNone(e4) @@ -127,33 +167,35 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can filter to the m.room.name event (with a '' state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Name, '')], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Name, '')]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Name, None)], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) state = yield self.store.get_state_for_event( - e5.event_id, [(EventTypes.Member, None)], filtered_types=None + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) self.assertStateMapEqual( {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state ) - # check we can use filtered_types to grab a specific room member - # without filtering out the other event types + # check we can grab a specific room member without filtering out the + # other event types state = yield self.store.get_state_for_event( e5.event_id, - [(EventTypes.Member, self.u_alice.to_string())], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ) ) self.assertStateMapEqual( @@ -165,10 +207,12 @@ class StateStoreTestCase(tests.unittest.TestCase): state, ) - # check that types=[], filtered_types=[EventTypes.Member] - # doesn't return all members + # check that we can grab everything except members state = yield self.store.get_state_for_event( - e5.event_id, [], filtered_types=[EventTypes.Member] + e5.event_id, state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertStateMapEqual( @@ -176,16 +220,21 @@ class StateStoreTestCase(tests.unittest.TestCase): ) ####################################################### - # _get_some_state_from_cache tests against a full cache + # _get_state_for_group_using_cache tests against a full cache ####################################################### room_id = self.room.to_string() group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id]) group = list(group_ids.keys())[0] - # test _get_some_state_from_cache correctly filters out members with types=[] - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - group, [], filtered_types=[EventTypes.Member] + # test _get_state_for_group_using_cache correctly filters out members + # with types=[] + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -197,9 +246,27 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - # test _get_some_state_from_cache correctly filters in members with wildcard types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_members_cache, + group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({}, state_dict) + + # test _get_state_for_group_using_cache correctly filters in members + # with wildcard types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, + group, + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -207,6 +274,22 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, + }, + state_dict, + ) + + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_members_cache, + group, + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), + ) + + self.assertEqual(is_all, True) + self.assertDictEqual( + { (e3.type, e3.state_key): e3.event_id, # e4 is overwritten by e5 (e5.type, e5.state_key): e5.event_id, @@ -214,11 +297,15 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - # test _get_some_state_from_cache correctly filters in members with specific types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, True) @@ -226,15 +313,31 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, - (e5.type, e5.state_key): e5.event_id, }, state_dict, ) - # test _get_some_state_from_cache correctly filters in members with specific types - # and no filtered_types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - group, [(EventTypes.Member, e5.state_key)], filtered_types=None + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_members_cache, + group, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_members_cache, + group, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, True) @@ -254,9 +357,6 @@ class StateStoreTestCase(tests.unittest.TestCase): { (e1.type, e1.state_key): e1.event_id, (e2.type, e2.state_key): e2.event_id, - (e3.type, e3.state_key): e3.event_id, - # e4 is overwritten by e5 - (e5.type, e5.state_key): e5.event_id, }, ) @@ -267,11 +367,7 @@ class StateStoreTestCase(tests.unittest.TestCase): key=group, value=state_dict_ids, # list fetched keys so it knows it's partial - fetched_keys=( - (e1.type, e1.state_key), - (e3.type, e3.state_key), - (e5.type, e5.state_key), - ), + fetched_keys=((e1.type, e1.state_key),), ) (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( @@ -279,73 +375,118 @@ class StateStoreTestCase(tests.unittest.TestCase): ) self.assertEqual(is_all, False) - self.assertEqual( - known_absent, - set( - [ - (e1.type, e1.state_key), - (e3.type, e3.state_key), - (e5.type, e5.state_key), - ] - ), - ) - self.assertDictEqual( - state_dict_ids, - { - (e1.type, e1.state_key): e1.event_id, - (e3.type, e3.state_key): e3.event_id, - (e5.type, e5.state_key): e5.event_id, - }, - ) + self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) + self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) ############################################ # test that things work with a partial cache - # test _get_some_state_from_cache correctly filters out members with types=[] + # test _get_state_for_group_using_cache correctly filters out members + # with types=[] room_id = self.room.to_string() - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - group, [], filtered_types=[EventTypes.Member] + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), ) self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - # test _get_some_state_from_cache correctly filters in members wildcard types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] + room_id = self.room.to_string() + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_members_cache, + group, + state_filter=StateFilter( + types={EventTypes.Member: set()}, + include_others=True, + ), + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({}, state_dict) + + # test _get_state_for_group_using_cache correctly filters in members + # wildcard types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, + group, + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), ) self.assertEqual(is_all, False) + self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) + + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_members_cache, + group, + state_filter=StateFilter( + types={EventTypes.Member: None}, + include_others=True, + ), + ) + + self.assertEqual(is_all, True) self.assertDictEqual( { - (e1.type, e1.state_key): e1.event_id, (e3.type, e3.state_key): e3.event_id, - # e4 is overwritten by e5 (e5.type, e5.state_key): e5.event_id, }, state_dict, ) - # test _get_some_state_from_cache correctly filters in members with specific types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, group, - [(EventTypes.Member, e5.state_key)], - filtered_types=[EventTypes.Member], + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), ) self.assertEqual(is_all, False) - self.assertDictEqual( - { - (e1.type, e1.state_key): e1.event_id, - (e5.type, e5.state_key): e5.event_id, - }, - state_dict, + self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) + + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_members_cache, + group, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=True, + ), + ) + + self.assertEqual(is_all, True) + self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + + # test _get_state_for_group_using_cache correctly filters in members + # with specific types + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_cache, + group, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) - # test _get_some_state_from_cache correctly filters in members with specific types - # and no filtered_types - (state_dict, is_all) = yield self.store._get_some_state_from_cache( - group, [(EventTypes.Member, e5.state_key)], filtered_types=None + self.assertEqual(is_all, False) + self.assertDictEqual({}, state_dict) + + (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( + self.store._state_group_members_cache, + group, + state_filter=StateFilter( + types={EventTypes.Member: {e5.state_key}}, + include_others=False, + ), ) self.assertEqual(is_all, True) diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py new file mode 100644 index 0000000000..14169afa96 --- /dev/null +++ b/tests/storage/test_transactions.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.unittest import HomeserverTestCase + + +class TransactionStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() + + def test_get_set_transactions(self): + """Tests that we can successfully get a non-existent entry for + destination retries, as well as testing tht we can set and get + correctly. + """ + d = self.store.get_destination_retry_timings("example.com") + r = self.get_success(d) + self.assertIsNone(r) + + d = self.store.set_destination_retry_timings("example.com", 50, 100) + self.get_success(d) + + d = self.store.get_destination_retry_timings("example.com") + r = self.get_success(d) + + self.assert_dict({"retry_last_ts": 50, "retry_interval": 100}, r) + + def test_initial_set_transactions(self): + """Tests that we can successfully set the destination retries (there + was a bug around invalidating the cache that broke this) + """ + d = self.store.set_destination_retry_timings("example.com", 50, 100) + self.get_success(d) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index b46e0ea7e2..0dde1ab2fe 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -30,7 +30,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = UserDirectoryStore(None, self.hs) + self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_federation.py b/tests/test_federation.py index 2540604fcc..952a0a7b51 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -6,6 +6,7 @@ from twisted.internet.defer import maybeDeferred, succeed from synapse.events import FrozenEvent from synapse.types import Requester, UserID from synapse.util import Clock +from synapse.util.logcontext import LoggingContext from tests import unittest from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver @@ -117,9 +118,10 @@ class MessageAcceptTests(unittest.TestCase): } ) - d = self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True - ) + with LoggingContext(request="lying_event"): + d = self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) # Step the reactor, so the database fetches come back self.reactor.advance(1) @@ -139,107 +141,3 @@ class MessageAcceptTests(unittest.TestCase): self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id ) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") - - def test_cant_hide_past_history(self): - """ - If you send a message, you must be able to provide the direct - prev_events that said event references. - """ - - def post_json(destination, path, data, headers=None, timeout=0): - if path.startswith("/_matrix/federation/v1/get_missing_events/"): - return { - "events": [ - { - "room_id": self.room_id, - "sender": "@baduser:test.serv", - "event_id": "three:test.serv", - "depth": 1000, - "origin_server_ts": 1, - "type": "m.room.message", - "origin": "test.serv", - "content": "hewwo?", - "auth_events": [], - "prev_events": [("four:test.serv", {})], - } - ] - } - - self.http_client.post_json = post_json - - def get_json(destination, path, args, headers=None): - if path.startswith("/_matrix/federation/v1/state_ids/"): - d = self.successResultOf( - self.homeserver.datastore.get_state_ids_for_event("one:test.serv") - ) - - return succeed( - { - "pdu_ids": [ - y - for x, y in d.items() - if x == ("m.room.member", "@us:test") - ], - "auth_chain_ids": list(d.values()), - } - ) - - self.http_client.get_json = get_json - - # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id - ) - )[0] - - # Make a good event - good_event = FrozenEvent( - { - "room_id": self.room_id, - "sender": "@baduser:test.serv", - "event_id": "one:test.serv", - "depth": 1000, - "origin_server_ts": 1, - "type": "m.room.message", - "origin": "test.serv", - "content": "hewwo?", - "auth_events": [], - "prev_events": [(most_recent, {})], - } - ) - - d = self.handler.on_receive_pdu( - "test.serv", good_event, sent_to_us_directly=True - ) - self.reactor.advance(1) - self.assertEqual(self.successResultOf(d), None) - - bad_event = FrozenEvent( - { - "room_id": self.room_id, - "sender": "@baduser:test.serv", - "event_id": "two:test.serv", - "depth": 1000, - "origin_server_ts": 1, - "type": "m.room.message", - "origin": "test.serv", - "content": "hewwo?", - "auth_events": [], - "prev_events": [("one:test.serv", {}), ("three:test.serv", {})], - } - ) - - d = self.handler.on_receive_pdu( - "test.serv", bad_event, sent_to_us_directly=True - ) - self.reactor.advance(1) - - extrem = maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id - ) - self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv") - - state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id) - self.reactor.advance(1) - self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys()) diff --git a/tests/test_mau.py b/tests/test_mau.py new file mode 100644 index 0000000000..bdbacb8448 --- /dev/null +++ b/tests/test_mau.py @@ -0,0 +1,217 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests REST events for /rooms paths.""" + +import json + +from mock import Mock, NonCallableMock + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes, HttpResponseException, SynapseError +from synapse.http.server import JsonResource +from synapse.rest.client.v2_alpha import register, sync +from synapse.util import Clock + +from tests import unittest +from tests.server import ( + ThreadedMemoryReactorClock, + make_request, + render, + setup_test_homeserver, +) + + +class TestMauLimit(unittest.TestCase): + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + self.clock = Clock(self.reactor) + + self.hs = setup_test_homeserver( + self.addCleanup, + "red", + http_client=None, + clock=self.clock, + reactor=self.reactor, + federation_client=Mock(), + ratelimiter=NonCallableMock(spec_set=["send_message"]), + ) + + self.store = self.hs.get_datastore() + + self.hs.config.registrations_require_3pid = [] + self.hs.config.enable_registration_captcha = False + self.hs.config.recaptcha_public_key = [] + + self.hs.config.limit_usage_by_mau = True + self.hs.config.hs_disabled = False + self.hs.config.max_mau_value = 2 + self.hs.config.mau_trial_days = 0 + self.hs.config.server_notices_mxid = "@server:red" + self.hs.config.server_notices_mxid_display_name = None + self.hs.config.server_notices_mxid_avatar_url = None + self.hs.config.server_notices_room_name = "Test Server Notice Room" + + self.resource = JsonResource(self.hs) + register.register_servlets(self.hs, self.resource) + sync.register_servlets(self.hs, self.resource) + + def test_simple_deny_mau(self): + # Create and sync so that the MAU counts get updated + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + + # We've created and activated two users, we shouldn't be able to + # register new users + with self.assertRaises(SynapseError) as cm: + self.create_user("kermit3") + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_allowed_after_a_month_mau(self): + # Create and sync so that the MAU counts get updated + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + + # Advance time by 31 days + self.reactor.advance(31 * 24 * 60 * 60) + + self.store.reap_monthly_active_users() + + self.reactor.advance(0) + + # We should be able to register more users + token3 = self.create_user("kermit3") + self.do_sync_for_user(token3) + + def test_trial_delay(self): + self.hs.config.mau_trial_days = 1 + + # We should be able to register more than the limit initially + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + token3 = self.create_user("kermit3") + self.do_sync_for_user(token3) + + # Advance time by 2 days + self.reactor.advance(2 * 24 * 60 * 60) + + # Two users should be able to sync + self.do_sync_for_user(token1) + self.do_sync_for_user(token2) + + # But the third should fail + with self.assertRaises(SynapseError) as cm: + self.do_sync_for_user(token3) + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + # And new registrations are now denied too + with self.assertRaises(SynapseError) as cm: + self.create_user("kermit4") + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_trial_users_cant_come_back(self): + self.hs.config.mau_trial_days = 1 + + # We should be able to register more than the limit initially + token1 = self.create_user("kermit1") + self.do_sync_for_user(token1) + token2 = self.create_user("kermit2") + self.do_sync_for_user(token2) + token3 = self.create_user("kermit3") + self.do_sync_for_user(token3) + + # Advance time by 2 days + self.reactor.advance(2 * 24 * 60 * 60) + + # Two users should be able to sync + self.do_sync_for_user(token1) + self.do_sync_for_user(token2) + + # Advance by 2 months so everyone falls out of MAU + self.reactor.advance(60 * 24 * 60 * 60) + self.store.reap_monthly_active_users() + self.reactor.advance(0) + + # We can create as many new users as we want + token4 = self.create_user("kermit4") + self.do_sync_for_user(token4) + token5 = self.create_user("kermit5") + self.do_sync_for_user(token5) + token6 = self.create_user("kermit6") + self.do_sync_for_user(token6) + + # users 2 and 3 can come back to bring us back up to MAU limit + self.do_sync_for_user(token2) + self.do_sync_for_user(token3) + + # New trial users can still sync + self.do_sync_for_user(token4) + self.do_sync_for_user(token5) + self.do_sync_for_user(token6) + + # But old user cant + with self.assertRaises(SynapseError) as cm: + self.do_sync_for_user(token1) + + e = cm.exception + self.assertEqual(e.code, 403) + self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def create_user(self, localpart): + request_data = json.dumps( + { + "username": localpart, + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + } + ) + + request, channel = make_request("POST", "/register", request_data) + render(request, self.resource, self.reactor) + + if channel.code != 200: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"] + ).to_synapse_error() + + access_token = channel.json_body["access_token"] + + return access_token + + def do_sync_for_user(self, token): + request, channel = make_request( + "GET", "/sync", access_token=token.encode('ascii') + ) + render(request, self.resource, self.reactor) + + if channel.code != 200: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"] + ).to_synapse_error() diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000000..17897711a1 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.metrics import InFlightGauge + +from tests import unittest + + +class TestMauLimit(unittest.TestCase): + def test_basic(self): + gauge = InFlightGauge( + "test1", "", + labels=["test_label"], + sub_metrics=["foo", "bar"], + ) + + def handle1(metrics): + metrics.foo += 2 + metrics.bar = max(metrics.bar, 5) + + def handle2(metrics): + metrics.foo += 3 + metrics.bar = max(metrics.bar, 7) + + gauge.register(("key1",), handle1) + + self.assert_dict({ + "test1_total": {("key1",): 1}, + "test1_foo": {("key1",): 2}, + "test1_bar": {("key1",): 5}, + }, self.get_metrics_from_gauge(gauge)) + + gauge.unregister(("key1",), handle1) + + self.assert_dict({ + "test1_total": {("key1",): 0}, + "test1_foo": {("key1",): 0}, + "test1_bar": {("key1",): 0}, + }, self.get_metrics_from_gauge(gauge)) + + gauge.register(("key1",), handle1) + gauge.register(("key2",), handle2) + + self.assert_dict({ + "test1_total": {("key1",): 1, ("key2",): 1}, + "test1_foo": {("key1",): 2, ("key2",): 3}, + "test1_bar": {("key1",): 5, ("key2",): 7}, + }, self.get_metrics_from_gauge(gauge)) + + gauge.unregister(("key2",), handle2) + gauge.register(("key1",), handle2) + + self.assert_dict({ + "test1_total": {("key1",): 2, ("key2",): 0}, + "test1_foo": {("key1",): 5, ("key2",): 0}, + "test1_bar": {("key1",): 7, ("key2",): 0}, + }, self.get_metrics_from_gauge(gauge)) + + def get_metrics_from_gauge(self, gauge): + results = {} + + for r in gauge.collect(): + results[r.name] = { + tuple(labels[x] for x in gauge.labels): value + for _, labels, value in r.samples + } + + return results diff --git a/tests/test_server.py b/tests/test_server.py index ef74544e93..4045fdadc3 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,14 +1,35 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging import re +from six import StringIO + from twisted.internet.defer import Deferred -from twisted.test.proto_helpers import MemoryReactorClock +from twisted.python.failure import Failure +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.web.resource import Resource +from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError from synapse.http.server import JsonResource +from synapse.http.site import SynapseSite, logger from synapse.util import Clock from tests import unittest -from tests.server import make_request, render, setup_test_homeserver +from tests.server import FakeTransport, make_request, render, setup_test_homeserver class JsonResourceTests(unittest.TestCase): @@ -121,3 +142,52 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.result["code"], b'400') self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + +class SiteTestCase(unittest.HomeserverTestCase): + def test_lose_connection(self): + """ + We log the URI correctly redacted when we lose the connection. + """ + + class HangingResource(Resource): + """ + A Resource that strategically hangs, as if it were processing an + answer. + """ + + def render(self, request): + return NOT_DONE_YET + + # Set up a logging handler that we can inspect afterwards + output = StringIO() + handler = logging.StreamHandler(output) + logger.addHandler(handler) + old_level = logger.level + logger.setLevel(10) + self.addCleanup(logger.setLevel, old_level) + self.addCleanup(logger.removeHandler, handler) + + # Make a resource and a Site, the resource will hang and allow us to + # time out the request while it's 'processing' + base_resource = Resource() + base_resource.putChild(b'', HangingResource()) + site = SynapseSite("test", "site_tag", {}, base_resource, "1.0") + + server = site.buildProtocol(None) + client = AccumulatingProtocol() + client.makeConnection(FakeTransport(server, self.reactor)) + server.makeConnection(FakeTransport(client, self.reactor)) + + # Send a request with an access token that will get redacted + server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n") + self.pump() + + # Lose the connection + e = Failure(Exception("Failed123")) + server.connectionLost(e) + handler.flush() + + # Our access token is redacted and the failure reason is logged. + self.assertIn("/?access_token=<redacted>", output.getvalue()) + self.assertIn("Failed123", output.getvalue()) diff --git a/tests/test_state.py b/tests/test_state.py index 96fdb8636c..e20c33322a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -18,7 +18,7 @@ from mock import Mock from twisted.internet import defer from synapse.api.auth import Auth -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RoomVersions from synapse.events import FrozenEvent from synapse.state import StateHandler, StateResolutionHandler @@ -117,6 +117,9 @@ class StateGroupStore(object): def register_event_id_state_group(self, event_id, state_group): self._event_to_state_group[event_id] = state_group + def get_room_version(self, room_id): + return RoomVersions.V1 + class DictObj(dict): def __init__(self, **kwargs): @@ -176,7 +179,9 @@ class StateTestCase(unittest.TestCase): def test_branch_no_conflict(self): graph = Graph( nodes={ - "START": DictObj(type=EventTypes.Create, state_key="", depth=1), + "START": DictObj( + type=EventTypes.Create, state_key="", content={}, depth=1 + ), "A": DictObj(type=EventTypes.Message, depth=2), "B": DictObj(type=EventTypes.Message, depth=3), "C": DictObj(type=EventTypes.Name, state_key="", depth=3), diff --git a/tests/test_types.py b/tests/test_types.py index be072d402b..0f5c8bfaf9 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -14,12 +14,12 @@ # limitations under the License. from synapse.api.errors import SynapseError -from synapse.server import HomeServer from synapse.types import GroupID, RoomAlias, UserID from tests import unittest +from tests.utils import TestHomeServer -mock_homeserver = HomeServer(hostname="my.domain") +mock_homeserver = TestHomeServer(hostname="my.domain") class UserIDTestCase(unittest.TestCase): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 45a78338d6..2eea3b098b 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -21,7 +21,7 @@ from synapse.events import FrozenEvent from synapse.visibility import filter_events_for_server import tests.unittest -from tests.utils import setup_test_homeserver +from tests.utils import create_room, setup_test_homeserver logger = logging.getLogger(__name__) @@ -36,6 +36,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_builder_factory = self.hs.get_event_builder_factory() self.store = self.hs.get_datastore() + yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") + @defer.inlineCallbacks def test_filtering(self): # @@ -94,7 +96,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): events_to_filter.append(evt) # the erasey user gets erased - self.hs.get_datastore().mark_user_erased("@erased:local_hs") + yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") # ... and the filtering happens. filtered = yield filter_events_for_server( diff --git a/tests/unittest.py b/tests/unittest.py index d852e2465a..a59291cc60 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import hmac import logging from mock import Mock @@ -22,14 +24,17 @@ from canonicaljson import json import twisted import twisted.logger +from twisted.internet.defer import Deferred from twisted.trial import unittest from synapse.http.server import JsonResource +from synapse.http.site import SynapseRequest from synapse.server import HomeServer from synapse.types import UserID, create_requester from synapse.util.logcontext import LoggingContextFilter from tests.server import get_clock, make_request, render, setup_test_homeserver +from tests.utils import default_config # Set up putting Synapse's logs into Trial's. rootLogger = logging.getLogger() @@ -151,6 +156,7 @@ class HomeserverTestCase(TestCase): hijack_auth (bool): Whether to hijack auth to return the user specified in user_id. """ + servlets = [] hijack_auth = True @@ -217,7 +223,17 @@ class HomeserverTestCase(TestCase): Function to be overridden in subclasses. """ - raise NotImplementedError() + hs = self.setup_test_homeserver() + return hs + + def default_config(self, name="test"): + """ + Get a default HomeServer config object. + + Args: + name (str): The homeserver name/domain. + """ + return default_config(name) def prepare(self, reactor, clock, homeserver): """ @@ -234,7 +250,9 @@ class HomeserverTestCase(TestCase): Function to optionally be overridden in subclasses. """ - def make_request(self, method, path, content=b""): + def make_request( + self, method, path, content=b"", access_token=None, request=SynapseRequest + ): """ Create a SynapseRequest at the path using the method and containing the given content. @@ -252,7 +270,7 @@ class HomeserverTestCase(TestCase): if isinstance(content, dict): content = json.dumps(content).encode('utf8') - return make_request(method, path, content) + return make_request(method, path, content, access_token, request) def render(self, request): """ @@ -279,3 +297,81 @@ class HomeserverTestCase(TestCase): kwargs = dict(kwargs) kwargs.update(self._hs_args) return setup_test_homeserver(self.addCleanup, *args, **kwargs) + + def pump(self, by=0.0): + """ + Pump the reactor enough that Deferreds will fire. + """ + self.reactor.pump([by] * 100) + + def get_success(self, d): + if not isinstance(d, Deferred): + return d + self.pump() + return self.successResultOf(d) + + def register_user(self, username, password, admin=False): + """ + Register a user. Requires the Admin API be registered. + + Args: + username (bytes/unicode): The user part of the new user. + password (bytes/unicode): The password of the new user. + admin (bool): Whether the user should be created as an admin + or not. + + Returns: + The MXID of the new user (unicode). + """ + self.hs.config.registration_shared_secret = u"shared" + + # Create the user + request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')]) + if admin: + nonce_str += b"\x00admin" + else: + nonce_str += b"\x00notadmin" + want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str) + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": username, + "password": password, + "admin": admin, + "mac": want_mac, + } + ) + request, channel = self.make_request( + "POST", "/_matrix/client/r0/admin/register", body.encode('utf8') + ) + self.render(request) + self.assertEqual(channel.code, 200) + + user_id = channel.json_body["user_id"] + return user_id + + def login(self, username, password, device_id=None): + """ + Log in a user, and get an access token. Requires the Login API be + registered. + + """ + body = {"type": "m.login.password", "user": username, "password": password} + if device_id: + body["device_id"] = device_id + + request, channel = self.make_request( + "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8') + ) + self.render(request) + self.assertEqual(channel.code, 200) + + access_token = channel.json_body["access_token"].encode('ascii') + return access_token diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 5cbada4eda..50bc7702d2 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -65,7 +65,6 @@ class ExpiringCacheTestCase(unittest.TestCase): def test_time_eviction(self): clock = MockClock() cache = ExpiringCache("test", clock, expiry_ms=1000) - cache.start() cache["key"] = 1 clock.advance_time(0.5) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 4633db77b3..8adaee3c8d 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -159,6 +159,11 @@ class LoggingContextTestCase(unittest.TestCase): self.assertEqual(r, "bum") self._check_test_key("one") + def test_nested_logging_context(self): + with LoggingContext(request="foo"): + nested_context = logcontext.nested_logging_context(suffix="bar") + self.assertEqual(nested_context.request, "foo-bar") + # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on diff --git a/tests/utils.py b/tests/utils.py index f1683e7a06..022a868501 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,7 +16,9 @@ import atexit import hashlib import os +import time import uuid +import warnings from inspect import getcallargs from mock import Mock, patch @@ -24,12 +26,14 @@ from six.moves.urllib import parse as urlparse from twisted.internet import defer, reactor +from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error +from synapse.config.server import ServerConfig from synapse.federation.transport import server from synapse.http.server import HttpServer from synapse.server import HomeServer -from synapse.storage import PostgresEngine -from synapse.storage.engines import create_engine +from synapse.storage import DataStore +from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import ( _get_or_create_schema_state, _setup_new_database, @@ -40,6 +44,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter # set this to True to run the tests against postgres instead of sqlite. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) +LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False) POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres") POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) @@ -91,10 +96,79 @@ def setupdb(): atexit.register(_cleanup) +def default_config(name): + """ + Create a reasonable test config. + """ + config = Mock() + config.signing_key = [MockKey()] + config.event_cache_size = 1 + config.enable_registration = True + config.macaroon_secret_key = "not even a little secret" + config.expire_access_token = False + config.server_name = name + config.trusted_third_party_id_servers = [] + config.room_invite_state_types = [] + config.password_providers = [] + config.worker_replication_url = "" + config.worker_app = None + config.email_enable_notifs = False + config.block_non_admin_invites = False + config.federation_domain_whitelist = None + config.federation_rc_reject_limit = 10 + config.federation_rc_sleep_limit = 10 + config.federation_rc_sleep_delay = 100 + config.federation_rc_concurrent = 10 + config.filter_timeline_limit = 5000 + config.user_directory_search_all_users = False + config.replicate_user_profiles_to = [] + config.user_consent_server_notice_content = None + config.block_events_without_consent_error = None + config.media_storage_providers = [] + config.autocreate_auto_join_rooms = True + config.auto_join_rooms = [] + config.limit_usage_by_mau = False + config.hs_disabled = False + config.hs_disabled_message = "" + config.hs_disabled_limit_type = "" + config.max_mau_value = 50 + config.mau_trial_days = 0 + config.mau_limits_reserved_threepids = [] + config.admin_contact = None + config.rc_messages_per_second = 10000 + config.rc_message_burst_count = 10000 + + config.use_frozen_dicts = False + + # we need a sane default_room_version, otherwise attempts to create rooms will + # fail. + config.default_room_version = "1" + + # disable user directory updates, because they get done in the + # background, which upsets the test runner. + config.update_user_directory = False + + def is_threepid_reserved(threepid): + return ServerConfig.is_threepid_reserved(config, threepid) + + config.is_threepid_reserved.side_effect = is_threepid_reserved + + return config + + +class TestHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + + @defer.inlineCallbacks def setup_test_homeserver( - cleanup_func, name="test", datastore=None, config=None, reactor=None, - homeserverToUse=HomeServer, **kargs + cleanup_func, + name="test", + datastore=None, + config=None, + reactor=None, + homeserverToUse=TestHomeServer, + **kargs ): """ Setup a homeserver suitable for running tests against. Keyword arguments @@ -110,49 +184,8 @@ def setup_test_homeserver( from twisted.internet import reactor if config is None: - config = Mock() - config.signing_key = [MockKey()] - config.event_cache_size = 1 - config.enable_registration = True - config.macaroon_secret_key = "not even a little secret" - config.expire_access_token = False - config.server_name = name - config.trusted_third_party_id_servers = [] - config.room_invite_state_types = [] - config.password_providers = [] - config.worker_replication_url = "" - config.worker_app = None - config.email_enable_notifs = False - config.block_non_admin_invites = False - config.federation_domain_whitelist = None - config.federation_rc_reject_limit = 10 - config.federation_rc_sleep_limit = 10 - config.federation_rc_sleep_delay = 100 - config.federation_rc_concurrent = 10 - config.filter_timeline_limit = 5000 - config.user_directory_search_all_users = False - config.replicate_user_profiles_to = [] - config.user_consent_server_notice_content = None - config.block_events_without_consent_error = None - config.media_storage_providers = [] - config.auto_join_rooms = [] - config.limit_usage_by_mau = False - config.hs_disabled = False - config.hs_disabled_message = "" - config.hs_disabled_limit_type = "" - config.max_mau_value = 50 - config.mau_limits_reserved_threepids = [] - config.admin_uri = None - - # we need a sane default_room_version, otherwise attempts to create rooms will - # fail. - config.default_room_version = "1" - - # disable user directory updates, because they get done in the - # background, which upsets the test runner. - config.update_user_directory = False - - config.use_frozen_dicts = True + config = default_config(name) + config.ldap_enabled = False if "clock" not in kargs: @@ -218,22 +251,44 @@ def setup_test_homeserver( else: # We need to do cleanup on PostgreSQL def cleanup(): + import psycopg2 + # Close all the db pools hs.get_db_pool().close() + dropped = False + # Drop the test database db_conn = db_engine.module.connect( database=POSTGRES_BASE_DB, user=POSTGRES_USER ) db_conn.autocommit = True cur = db_conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for x in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + cur.close() db_conn.close() - # Register the cleanup hook - cleanup_func(cleanup) + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) hs.setup() else: @@ -307,7 +362,9 @@ class MockHttpResource(HttpServer): @patch('twisted.web.http.Request') @defer.inlineCallbacks - def trigger(self, http_method, path, content, mock_request, federation_auth=False): + def trigger( + self, http_method, path, content, mock_request, federation_auth_origin=None + ): """ Fire an HTTP event. Args: @@ -316,6 +373,7 @@ class MockHttpResource(HttpServer): content : The HTTP body mock_request : Mocked request to pass to the event so it can get content. + federation_auth_origin (bytes|None): domain to authenticate as, for federation Returns: A tuple of (code, response) Raises: @@ -336,8 +394,10 @@ class MockHttpResource(HttpServer): mock_request.getClientIP.return_value = "-" headers = {} - if federation_auth: - headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="] + if federation_auth_origin is not None: + headers[b"Authorization"] = [ + b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,) + ] mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) # return the right path if the event requires it @@ -540,3 +600,32 @@ class DeferredMockCallable(object): "Expected not to received any calls, got:\n" + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls]) ) + + +@defer.inlineCallbacks +def create_room(hs, room_id, creator_id): + """Creates and persist a creation event for the given room + + Args: + hs + room_id (str) + creator_id (str) + """ + + store = hs.get_datastore() + event_builder_factory = hs.get_event_builder_factory() + event_creation_handler = hs.get_event_creation_handler() + + builder = event_builder_factory.new( + { + "type": EventTypes.Create, + "state_key": "", + "sender": creator_id, + "room_id": room_id, + "content": {}, + } + ) + + event, context = yield event_creation_handler.create_new_client_event(builder) + + yield store.persist_event(event, context) |