diff options
author | Richard van der Hoff <richard@matrix.org> | 2021-02-17 16:31:57 +0000 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2021-02-17 16:31:57 +0000 |
commit | 7b7831bb6363a625c97446298838c66abfeb6b8b (patch) | |
tree | 39ec72cc7b8985858012f5d77fb89796fb04ff43 /tests/rest/client | |
parent | Ensure that we never stop reconnecting to redis (#9391) (diff) | |
parent | Reorganize CONTRIBUTING.md documentation. (#9281) (diff) | |
download | synapse-7b7831bb6363a625c97446298838c66abfeb6b8b.tar.xz |
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'tests/rest/client')
-rw-r--r-- | tests/rest/client/test_power_levels.py | 4 | ||||
-rw-r--r-- | tests/rest/client/test_redactions.py | 3 | ||||
-rw-r--r-- | tests/rest/client/test_retention.py | 3 | ||||
-rw-r--r-- | tests/rest/client/test_shadow_banned.py | 10 | ||||
-rw-r--r-- | tests/rest/client/v1/test_events.py | 4 | ||||
-rw-r--r-- | tests/rest/client/v1/test_login.py | 34 | ||||
-rw-r--r-- | tests/rest/client/v1/test_profile.py | 249 | ||||
-rw-r--r-- | tests/rest/client/v1/test_rooms.py | 16 | ||||
-rw-r--r-- | tests/rest/client/v1/test_typing.py | 32 | ||||
-rw-r--r-- | tests/rest/client/v1/utils.py | 36 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_account.py | 76 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_auth.py | 24 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_password_policy.py | 20 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_relations.py | 27 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_sync.py | 23 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_upgrade_room.py | 161 |
16 files changed, 439 insertions, 283 deletions
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index 913ea3c98e..5256c11fe6 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -73,7 +73,9 @@ class PowerLevelsTestCase(HomeserverTestCase): # Mod the mod room_power_levels = self.helper.get_state( - self.room_id, "m.room.power_levels", tok=self.admin_access_token, + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, ) # Update existing power levels with mod at PL50 diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index f0707646bb..e0c74591b6 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -181,8 +181,7 @@ class RedactionsTestCase(HomeserverTestCase): ) def test_redact_event_as_moderator_ratelimit(self): - """Tests that the correct ratelimiting is applied to redactions - """ + """Tests that the correct ratelimiting is applied to redactions""" message_ids = [] # as a regular user, send messages to redact diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 31dc832fd5..aee99bb6a0 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -250,7 +250,8 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["backfill"]) self.hs = self.setup_test_homeserver( - config=config, federation_client=mock_federation_client, + config=config, + federation_client=mock_federation_client, ) return self.hs diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 0ebdf1415b..d2cce44032 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -260,7 +260,10 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + self.banned_user_id, + room_id, + "m.room.member", + self.banned_user_id, ) ) self.assertEqual( @@ -292,7 +295,10 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + self.banned_user_id, + room_id, + "m.room.member", + self.banned_user_id, ) ) self.assertEqual( diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 0a5ca317ea..2ae896db1e 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -150,6 +150,8 @@ class GetEventsTestCase(unittest.HomeserverTestCase): event_id = resp["event_id"] channel = self.make_request( - "GET", "/events/" + event_id, access_token=self.token, + "GET", + "/events/" + event_id, + access_token=self.token, ) self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index bfcb786af8..fb29eaed6f 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,7 +15,7 @@ import time import urllib.parse -from typing import Any, Dict, Union +from typing import Any, Dict, List, Union from urllib.parse import urlencode from mock import Mock @@ -493,13 +493,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # parse the form to check it has fields assumed elsewhere in this class + html = channel.result["body"].decode("utf-8") p = TestHtmlParser() - p.feed(channel.result["body"].decode("utf-8")) + p.feed(html) p.close() - self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"]) + # there should be a link for each href + returned_idps = [] # type: List[str] + for link in p.links: + path, query = link.split("?", 1) + self.assertEqual(path, "pick_idp") + params = urllib.parse.parse_qs(query) + self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL]) + returned_idps.append(params["idp"][0]) - self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) + self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) def test_multi_sso_redirect_to_cas(self): """If CAS is chosen, should redirect to the CAS server""" @@ -603,7 +611,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # matrix access token, mxid, and device id. login_token = params[2][1] chan = self.make_request( - "POST", "/login", content={"type": "m.login.token", "token": login_token}, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, ) self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") @@ -611,7 +621,8 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_multi_sso_redirect_to_unknown(self): """An unknown IdP should cause a 400""" channel = self.make_request( - "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", + "GET", + "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", ) self.assertEqual(channel.code, 400, channel.result) @@ -711,7 +722,8 @@ class CASTestCase(unittest.HomeserverTestCase): mocked_http_client.get_raw.side_effect = get_raw self.hs = self.setup_test_homeserver( - config=config, proxied_http_client=mocked_http_client, + config=config, + proxied_http_client=mocked_http_client, ) return self.hs @@ -1236,7 +1248,9 @@ class UsernamePickerTestCase(HomeserverTestCase): # looks ok. username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions self.assertIn( - session_id, username_mapping_sessions, "session id not found in map", + session_id, + username_mapping_sessions, + "session id not found in map", ) session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") @@ -1291,7 +1305,9 @@ class UsernamePickerTestCase(HomeserverTestCase): # finally, submit the matrix login token to the login API, which gives us our # matrix access token, mxid, and device id. chan = self.make_request( - "POST", "/login", content={"type": "m.login.token", "token": login_token}, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, ) self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index e59fa70baa..f3448c94dd 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,163 +14,11 @@ # limitations under the License. """Tests REST events for /profile paths.""" -import json - -from mock import Mock - -from twisted.internet import defer - -import synapse.types -from synapse.api.errors import AuthError, SynapseError from synapse.rest import admin from synapse.rest.client.v1 import login, profile, room from tests import unittest -from ....utils import MockHttpResource, setup_test_homeserver - -myid = "@1234ABCD:test" -PATH_PREFIX = "/_matrix/client/r0" - - -class MockHandlerProfileTestCase(unittest.TestCase): - """ Tests rest layer of profile management. - - Todo: move these into ProfileTestCase - """ - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.mock_handler = Mock( - spec=[ - "get_displayname", - "set_displayname", - "get_avatar_url", - "set_avatar_url", - "check_profile_query_allowed", - ] - ) - - self.mock_handler.get_displayname.return_value = defer.succeed(Mock()) - self.mock_handler.set_displayname.return_value = defer.succeed(Mock()) - self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock()) - self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock()) - self.mock_handler.check_profile_query_allowed.return_value = defer.succeed( - Mock() - ) - - hs = yield setup_test_homeserver( - self.addCleanup, - "test", - federation_http_client=None, - resource_for_client=self.mock_resource, - federation=Mock(), - federation_client=Mock(), - profile_handler=self.mock_handler, - ) - - async def _get_user_by_req(request=None, allow_guest=False): - return synapse.types.create_requester(myid) - - hs.get_auth().get_user_by_req = _get_user_by_req - - profile.register_servlets(hs, self.mock_resource) - - @defer.inlineCallbacks - def test_get_my_name(self): - mocked_get = self.mock_handler.get_displayname - mocked_get.return_value = defer.succeed("Frank") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/displayname" % (myid), None - ) - - self.assertEquals(200, code) - self.assertEquals({"displayname": "Frank"}, response) - self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD") - - @defer.inlineCallbacks - def test_set_my_name(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.return_value = defer.succeed(()) - - (code, response) = yield self.mock_resource.trigger( - "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}' - ) - - self.assertEquals(200, code) - self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.") - - @defer.inlineCallbacks - def test_set_my_name_noauth(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.side_effect = AuthError(400, "message") - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/displayname" % ("@4567:test"), - b'{"displayname": "Frank Jr."}', - ) - - self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code)) - - @defer.inlineCallbacks - def test_get_other_name(self): - mocked_get = self.mock_handler.get_displayname - mocked_get.return_value = defer.succeed("Bob") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None - ) - - self.assertEquals(200, code) - self.assertEquals({"displayname": "Bob"}, response) - - @defer.inlineCallbacks - def test_set_other_name(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.side_effect = SynapseError(400, "message") - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/displayname" % ("@opaque:elsewhere"), - b'{"displayname":"bob"}', - ) - - self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code)) - - @defer.inlineCallbacks - def test_get_my_avatar(self): - mocked_get = self.mock_handler.get_avatar_url - mocked_get.return_value = defer.succeed("http://my.server/me.png") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/avatar_url" % (myid), None - ) - - self.assertEquals(200, code) - self.assertEquals({"avatar_url": "http://my.server/me.png"}, response) - self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD") - - @defer.inlineCallbacks - def test_set_my_avatar(self): - mocked_set = self.mock_handler.set_avatar_url - mocked_set.return_value = defer.succeed(()) - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/avatar_url" % (myid), - b'{"avatar_url": "http://my.server/pic.gif"}', - ) - - self.assertEquals(200, code) - self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") - class ProfileTestCase(unittest.HomeserverTestCase): @@ -187,37 +35,122 @@ class ProfileTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") + self.other = self.register_user("other", "pass", displayname="Bob") + + def test_get_displayname(self): + res = self._get_displayname() + self.assertEqual(res, "owner") def test_set_displayname(self): channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), - content=json.dumps({"displayname": "test"}), + content={"displayname": "test"}, access_token=self.owner_tok, ) self.assertEqual(channel.code, 200, channel.result) - res = self.get_displayname() + res = self._get_displayname() self.assertEqual(res, "test") + def test_set_displayname_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test"}, + ) + self.assertEqual(channel.code, 401, channel.result) + def test_set_displayname_too_long(self): """Attempts to set a stupid displayname should get a 400""" channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), - content=json.dumps({"displayname": "test" * 100}), + content={"displayname": "test" * 100}, access_token=self.owner_tok, ) self.assertEqual(channel.code, 400, channel.result) - res = self.get_displayname() + res = self._get_displayname() self.assertEqual(res, "owner") - def get_displayname(self): - channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,)) + def test_get_displayname_other(self): + res = self._get_displayname(self.other) + self.assertEquals(res, "Bob") + + def test_set_displayname_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.other,), + content={"displayname": "test"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_get_avatar_url(self): + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_set_avatar_url(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res = self._get_avatar_url() + self.assertEqual(res, "http://my.server/pic.gif") + + def test_set_avatar_url_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + ) + self.assertEqual(channel.code, 401, channel.result) + + def test_set_avatar_url_too_long(self): + """Attempts to set a stupid avatar_url should get a 400""" + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif" * 100}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_get_avatar_url_other(self): + res = self._get_avatar_url(self.other) + self.assertIsNone(res) + + def test_set_avatar_url_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.other,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def _get_displayname(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/displayname" % (name or self.owner,) + ) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["displayname"] + def _get_avatar_url(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/avatar_url" % (name or self.owner,) + ) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body.get("avatar_url") + class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 2548b3a80c..ed65f645fc 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -46,7 +46,9 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", federation_http_client=None, federation_client=Mock(), + "red", + federation_http_client=None, + federation_client=Mock(), ) self.hs.get_federation_handler = Mock() @@ -1480,7 +1482,9 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), 2, [result["result"]["content"] for result in results], + len(results), + 2, + [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], @@ -1515,7 +1519,9 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), 4, [result["result"]["content"] for result in results], + len(results), + 4, + [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], @@ -1562,7 +1568,9 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), 1, [result["result"]["content"] for result in results], + len(results), + 1, + [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 38c51525a3..329dbd06de 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -18,8 +18,6 @@ from mock import Mock -from twisted.internet import defer - from synapse.rest.client.v1 import room from synapse.types import UserID @@ -39,7 +37,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", federation_http_client=None, federation_client=Mock(), + "red", + federation_http_client=None, + federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] @@ -60,32 +60,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_datastore().insert_client_ip = _insert_client_ip - def get_room_members(room_id): - if room_id == self.room_id: - return defer.succeed([self.user]) - else: - return defer.succeed([]) - - @defer.inlineCallbacks - def fetch_room_distributions_into( - room_id, localusers=None, remotedomains=None, ignore_user=None - ): - members = yield get_room_members(room_id) - for member in members: - if ignore_user is not None and member == ignore_user: - continue - - if hs.is_mine(member): - if localusers is not None: - localusers.add(member) - else: - if remotedomains is not None: - remotedomains.add(member.domain) - - hs.get_room_member_handler().fetch_room_distributions_into = ( - fetch_room_distributions_into - ) - return hs def prepare(self, reactor, clock, hs): diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index b1333df82d..8231a423f3 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -166,9 +166,12 @@ class RestHelper: json.dumps(data).encode("utf8"), ) - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], ) self.auth_user_id = temp_id @@ -201,9 +204,12 @@ class RestHelper: json.dumps(content).encode("utf8"), ) - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], ) return channel.json_body @@ -251,9 +257,12 @@ class RestHelper: channel = make_request(self.hs.get_reactor(), self.site, method, path, content) - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], ) return channel.json_body @@ -447,7 +456,10 @@ class RestHelper: return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) def complete_oidc_auth( - self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, + self, + oauth_uri: str, + cookies: Mapping[str, str], + user_info_dict: JsonDict, ) -> FakeChannel: """Mock out an OIDC authentication flow @@ -491,7 +503,9 @@ class RestHelper: (expected_uri, resp_obj) = expected_requests.pop(0) assert uri == expected_uri resp = FakeResponse( - code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"), + code=200, + phrase=b"OK", + body=json.dumps(resp_obj).encode("utf-8"), ) return resp diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 177dc476da..e72b61963d 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -75,8 +75,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.submit_token_resource = PasswordResetSubmitTokenResource(hs) def test_basic_password_reset(self): - """Test basic password reset flow - """ + """Test basic password reset flow""" old_password = "monkey" new_password = "kangeroo" @@ -114,8 +113,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): @override_config({"rc_3pid_validation": {"burst_count": 3}}) def test_ratelimit_by_email(self): - """Test that we ratelimit /requestToken for the same email. - """ + """Test that we ratelimit /requestToken for the same email.""" old_password = "monkey" new_password = "kangeroo" @@ -203,8 +201,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.attempt_wrong_password_login("kermit", old_password) def test_cant_reset_password_without_clicking_link(self): - """Test that we do actually need to click the link in the email - """ + """Test that we do actually need to click the link in the email""" old_password = "monkey" new_password = "kangeroo" @@ -299,7 +296,9 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): if channel.code != 200: raise HttpResponseException( - channel.code, channel.result["reason"], channel.result["body"], + channel.code, + channel.result["reason"], + channel.result["body"], ) return channel.json_body["sid"] @@ -566,8 +565,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): @override_config({"rc_3pid_validation": {"burst_count": 3}}) def test_ratelimit_by_ip(self): - """Tests that adding emails is ratelimited by IP - """ + """Tests that adding emails is ratelimited by IP""" # We expect to be able to set three emails before getting ratelimited. self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) @@ -580,8 +578,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(cm.exception.code, 429) def test_add_email_if_disabled(self): - """Test adding email to profile when doing so is disallowed - """ + """Test adding email to profile when doing so is disallowed""" self.hs.config.enable_3pid_changes = False client_secret = "foobar" @@ -611,15 +608,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email(self): - """Test deleting an email from profile - """ + """Test deleting an email from profile""" # Add a threepid self.get_success( self.store.user_add_threepid( @@ -641,15 +639,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email_if_disabled(self): - """Test deleting an email from profile when disallowed - """ + """Test deleting an email from profile when disallowed""" self.hs.config.enable_3pid_changes = False # Add a threepid @@ -675,7 +674,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -683,8 +684,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) def test_cant_add_email_without_clicking_link(self): - """Test that we do actually need to click the link in the email - """ + """Test that we do actually need to click the link in the email""" client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -710,7 +710,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -743,7 +745,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -788,7 +792,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Ensure not providing a next_link parameter still works self._request_token( - "something@example.com", "some_secret", next_link=None, expect_code=200, + "something@example.com", + "some_secret", + next_link=None, + expect_code=200, ) self._request_token( @@ -846,17 +853,27 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): if next_link: body["next_link"] = next_link - channel = self.make_request("POST", b"account/3pid/email/requestToken", body,) + channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + body, + ) if channel.code != expect_code: raise HttpResponseException( - channel.code, channel.result["reason"], channel.result["body"], + channel.code, + channel.result["reason"], + channel.result["body"], ) return channel.json_body.get("sid") def _request_token_invalid_email( - self, email, expected_errcode, expected_error, client_secret="foobar", + self, + email, + expected_errcode, + expected_error, + client_secret="foobar", ): channel = self.make_request( "POST", @@ -895,8 +912,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): return match.group(0) def _add_email(self, request_email, expected_email): - """Test adding an email to profile - """ + """Test adding an email to profile""" previous_email_attempts = len(self.email_attempts) client_secret = "foobar" @@ -926,7 +942,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 3f50c56745..501f09203f 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -102,7 +102,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase): """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"}, + 401, + {"username": "user", "type": "m.login.password", "password": "bar"}, ) # Grab the session @@ -191,7 +192,10 @@ class UIAuthTests(unittest.HomeserverTestCase): ) -> FakeChannel: """Delete an individual device.""" channel = self.make_request( - "DELETE", "devices/" + device, body, access_token=access_token, + "DELETE", + "devices/" + device, + body, + access_token=access_token, ) # Ensure the response is sane. @@ -204,7 +208,10 @@ class UIAuthTests(unittest.HomeserverTestCase): # Note that this uses the delete_devices endpoint so that we can modify # the payload half-way through some tests. channel = self.make_request( - "POST", "delete_devices", body, access_token=self.user_tok, + "POST", + "delete_devices", + body, + access_token=self.user_tok, ) # Ensure the response is sane. @@ -417,7 +424,10 @@ class UIAuthTests(unittest.HomeserverTestCase): # and now the delete request should succeed. self.delete_device( - self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}}, + self.user_tok, + self.device_id, + 200, + body={"auth": {"session": session_id}}, ) @skip_unless(HAS_OIDC, "requires OIDC") @@ -443,8 +453,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self): - """A user that had a password and then logged in with SSO should get both flows - """ + """A user that had a password and then logged in with SSO should get both flows""" login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) @@ -459,8 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_ui_auth_fails_for_incorrect_sso_user(self): - """If the user tries to authenticate with the wrong SSO user, they get an error - """ + """If the user tries to authenticate with the wrong SSO user, they get an error""" # log the user in login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py index fba34def30..5ebc5707a5 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -91,7 +91,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_TOO_SHORT, + channel.result, ) def test_password_no_digit(self): @@ -100,7 +102,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_DIGIT, + channel.result, ) def test_password_no_symbol(self): @@ -109,7 +113,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_SYMBOL, + channel.result, ) def test_password_no_uppercase(self): @@ -118,7 +124,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_UPPERCASE, + channel.result, ) def test_password_no_lowercase(self): @@ -127,7 +135,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_LOWERCASE, + channel.result, ) def test_password_compliant(self): diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index bd574077e7..7c457754f1 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -83,14 +83,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) def test_deny_membership(self): - """Test that we deny relations on membership events - """ + """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) self.assertEquals(400, channel.code, channel.json_body) def test_deny_double_react(self): - """Test that we deny relations on membership events - """ + """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") self.assertEquals(200, channel.code, channel.json_body) @@ -98,8 +96,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(400, channel.code, channel.json_body) def test_basic_paginate_relations(self): - """Tests that calling pagination API correctly the latest relations. - """ + """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") self.assertEquals(200, channel.code, channel.json_body) @@ -174,8 +171,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(found_event_ids, expected_event_ids) def test_aggregation_pagination_groups(self): - """Test that we can paginate annotation groups correctly. - """ + """Test that we can paginate annotation groups correctly.""" # We need to create ten separate users to send each reaction. access_tokens = [self.user_token, self.user2_token] @@ -240,8 +236,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(sent_groups, found_groups) def test_aggregation_pagination_within_group(self): - """Test that we can paginate within an annotation group. - """ + """Test that we can paginate within an annotation group.""" # We need to create ten separate users to send each reaction. access_tokens = [self.user_token, self.user2_token] @@ -311,8 +306,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(found_event_ids, expected_event_ids) def test_aggregation(self): - """Test that annotations get correctly aggregated. - """ + """Test that annotations get correctly aggregated.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -344,8 +338,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) def test_aggregation_redactions(self): - """Test that annotations get correctly aggregated after a redaction. - """ + """Test that annotations get correctly aggregated after a redaction.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -379,8 +372,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) def test_aggregation_must_be_annotation(self): - """Test that aggregations must be annotations. - """ + """Test that aggregations must be annotations.""" channel = self.make_request( "GET", @@ -437,8 +429,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) def test_edit(self): - """Test that a simple edit works. - """ + """Test that a simple edit works.""" new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 512e36c236..2dbf42397a 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -388,13 +388,19 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Check that room name changes increase the unread counter. self.helper.send_state( - self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, + self.room_id, + "m.room.name", + {"name": "my super room"}, + tok=self.tok2, ) self._check_unread_count(1) # Check that room topic changes increase the unread counter. self.helper.send_state( - self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, + self.room_id, + "m.room.topic", + {"topic": "welcome!!!"}, + tok=self.tok2, ) self._check_unread_count(2) @@ -404,7 +410,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Check that custom events with a body increase the unread counter. self.helper.send_event( - self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, + self.room_id, + "org.matrix.custom_type", + {"body": "hello"}, + tok=self.tok2, ) self._check_unread_count(4) @@ -443,14 +452,18 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): """Syncs and compares the unread count with the expected value.""" channel = self.make_request( - "GET", self.url % self.next_batch, access_token=self.tok, + "GET", + self.url % self.next_batch, + access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) room_entry = channel.json_body["rooms"]["join"][self.room_id] self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, + room_entry["org.matrix.msc2654.unread_count"], + expected_count, + room_entry, ) # Store the next batch for the next request. diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py new file mode 100644 index 0000000000..d890d11863 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_upgrade_room.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet + +from tests import unittest +from tests.server import FakeChannel + + +class UpgradeRoomTest(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + room_upgrade_rest_servlet.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + self.creator = self.register_user("creator", "pass") + self.creator_token = self.login(self.creator, "pass") + + self.other = self.register_user("user", "pass") + self.other_token = self.login(self.other, "pass") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token) + self.helper.join(self.room_id, self.other, tok=self.other_token) + + def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel: + # We never want a cached response. + self.reactor.advance(5 * 60 + 1) + + return self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id, + # This will upgrade a room to the same version, but that's fine. + content={"new_version": DEFAULT_ROOM_VERSION}, + access_token=token or self.creator_token, + ) + + def test_upgrade(self): + """ + Upgrading a room should work fine. + """ + channel = self._upgrade_room() + self.assertEquals(200, channel.code, channel.result) + self.assertIn("replacement_room", channel.json_body) + + def test_not_in_room(self): + """ + Upgrading a room should work fine. + """ + # THe user isn't in the room. + roomless = self.register_user("roomless", "pass") + roomless_token = self.login(roomless, "pass") + + channel = self._upgrade_room(roomless_token) + self.assertEquals(403, channel.code, channel.result) + + def test_power_levels(self): + """ + Another user can upgrade the room if their power level is increased. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["users"][self.other] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_user_default(self): + """ + Another user can upgrade the room if the default power level for users is increased. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["users_default"] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_tombstone(self): + """ + Another user can upgrade the room if they can send the tombstone event. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["events"]["m.room.tombstone"] = 0 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + self.assertNotIn(self.other, power_levels["users"]) |