diff options
Diffstat (limited to 'tests/rest/admin/test_registration_tokens.py')
-rw-r--r-- | tests/rest/admin/test_registration_tokens.py | 211 |
1 files changed, 164 insertions, 47 deletions
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 9bac423ae0..63087955f2 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -14,6 +14,7 @@ import random import string +from http import HTTPStatus import synapse.rest.admin from synapse.api.errors import Codes @@ -63,7 +64,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): def test_create_no_auth(self): """Try to create a token without authentication.""" channel = self.make_request("POST", self.url + "/new", {}) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_create_requester_not_admin(self): @@ -74,7 +79,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_create_using_defaults(self): @@ -86,7 +95,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -110,7 +119,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) @@ -131,7 +140,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -149,7 +158,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_token_invalid_chars(self): @@ -165,7 +178,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_token_already_exists(self): @@ -180,7 +197,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"]) + self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body) channel2 = self.make_request( "POST", @@ -188,7 +205,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"]) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body) self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) def test_create_unable_to_generate_token(self): @@ -220,7 +237,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(500, channel.code, msg=channel.json_body) def test_create_uses_allowed(self): """Check you can only create a token with good values for uses_allowed.""" @@ -231,7 +248,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 0) # Should fail with negative integer @@ -241,7 +258,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -251,7 +272,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_expiry_time(self): @@ -263,7 +288,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() - 10000}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -273,7 +302,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() + 1000000.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_length(self): @@ -285,7 +318,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 64}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 64) # Should fail with 0 @@ -295,7 +328,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -305,7 +342,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": -5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a float @@ -315,7 +356,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 8.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with 65 @@ -325,7 +370,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 65}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # UPDATING @@ -337,7 +386,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_update_requester_not_admin(self): @@ -348,7 +401,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_update_non_existent(self): @@ -360,7 +417,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.NOT_FOUND, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_update_uses_allowed(self): @@ -375,7 +436,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertIsNone(channel.json_body["expiry_time"]) @@ -386,7 +447,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 0) self.assertIsNone(channel.json_body["expiry_time"]) @@ -397,7 +458,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": None}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -408,7 +469,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -418,7 +483,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_update_expiry_time(self): @@ -434,7 +503,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -445,7 +514,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": None}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -457,7 +526,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": past_time}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail a float @@ -467,7 +540,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time + 0.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_update_both(self): @@ -488,7 +565,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) @@ -509,7 +586,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # DELETING @@ -521,7 +602,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_delete_requester_not_admin(self): @@ -532,7 +617,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_delete_non_existent(self): @@ -544,7 +633,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.NOT_FOUND, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_delete(self): @@ -559,7 +652,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # GETTING ONE @@ -570,7 +663,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_get_requester_not_admin(self): @@ -581,7 +678,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_get_non_existent(self): @@ -593,7 +694,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.NOT_FOUND, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_get(self): @@ -608,7 +713,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["token"], token) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -620,7 +725,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): def test_list_no_auth(self): """Try to list tokens without authentication.""" channel = self.make_request("GET", self.url, {}) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_list_requester_not_admin(self): @@ -631,7 +740,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_list_all(self): @@ -646,7 +759,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["registration_tokens"]), 1) token_info = channel.json_body["registration_tokens"][0] self.assertEqual(token_info["token"], token) @@ -664,7 +777,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) def _test_list_query_parameter(self, valid: str): """Helper used to test both valid=true and valid=false.""" @@ -696,7 +813,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["registration_tokens"]), 2) token_info_1 = channel.json_body["registration_tokens"][0] token_info_2 = channel.json_body["registration_tokens"][1] |