diff options
Diffstat (limited to 'tests/rest/admin/test_registration_tokens.py')
-rw-r--r-- | tests/rest/admin/test_registration_tokens.py | 288 |
1 files changed, 84 insertions, 204 deletions
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 350a62dda6..9bac423ae0 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -11,17 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import random import string -from http import HTTPStatus - -from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login -from synapse.server import HomeServer -from synapse.util import Clock from tests import unittest @@ -32,7 +28,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -42,7 +38,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/registration_tokens" - def _new_token(self, **kwargs) -> str: + def _new_token(self, **kwargs): """Helper function to create a token.""" token = kwargs.get( "token", @@ -64,17 +60,13 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # CREATION - def test_create_no_auth(self) -> None: + def test_create_no_auth(self): """Try to create a token without authentication.""" channel = self.make_request("POST", self.url + "/new", {}) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_create_requester_not_admin(self) -> None: + def test_create_requester_not_admin(self): """Try to create a token while not an admin.""" channel = self.make_request( "POST", @@ -82,14 +74,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_create_using_defaults(self) -> None: + def test_create_using_defaults(self): """Create a token using all the defaults.""" channel = self.make_request( "POST", @@ -98,14 +86,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_specifying_fields(self) -> None: + def test_create_specifying_fields(self): """Create a token specifying the value of all fields.""" # As many of the allowed characters as possible with length <= 64 token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-" @@ -122,14 +110,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_with_null_value(self) -> None: + def test_create_with_null_value(self): """Create a token specifying unlimited uses and no expiry.""" data = { "uses_allowed": None, @@ -143,14 +131,14 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_token_too_long(self) -> None: + def test_create_token_too_long(self): """Check token longer than 64 chars is invalid.""" data = {"token": "a" * 65} @@ -161,14 +149,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_invalid_chars(self) -> None: + def test_create_token_invalid_chars(self): """Check you can't create token with invalid characters.""" data = { "token": "abc/def", @@ -181,14 +165,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_already_exists(self) -> None: + def test_create_token_already_exists(self): """Check you can't create token that already exists.""" data = { "token": "abcd", @@ -200,7 +180,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body) + self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"]) channel2 = self.make_request( "POST", @@ -208,10 +188,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body) + self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"]) self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_unable_to_generate_token(self) -> None: + def test_create_unable_to_generate_token(self): """Check right error is raised when server can't generate unique token.""" # Create all possible single character tokens tokens = [] @@ -240,9 +220,9 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(500, channel.code, msg=channel.json_body) + self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"]) - def test_create_uses_allowed(self) -> None: + def test_create_uses_allowed(self): """Check you can only create a token with good values for uses_allowed.""" # Should work with 0 (token is invalid from the start) channel = self.make_request( @@ -251,7 +231,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 0) # Should fail with negative integer @@ -261,11 +241,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -275,14 +251,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_expiry_time(self) -> None: + def test_create_expiry_time(self): """Check you can't create a token with an invalid expiry_time.""" # Should fail with a time in the past channel = self.make_request( @@ -291,11 +263,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() - 10000}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -305,14 +273,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() + 1000000.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_length(self) -> None: + def test_create_length(self): """Check you can only generate a token with a valid length.""" # Should work with 64 channel = self.make_request( @@ -321,7 +285,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 64}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["token"]), 64) # Should fail with 0 @@ -331,11 +295,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 0}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -345,11 +305,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a float @@ -359,11 +315,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 8.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with 65 @@ -373,30 +325,22 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 65}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # UPDATING - def test_update_no_auth(self) -> None: + def test_update_no_auth(self): """Try to update a token without authentication.""" channel = self.make_request( "PUT", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_update_requester_not_admin(self) -> None: + def test_update_requester_not_admin(self): """Try to update a token while not an admin.""" channel = self.make_request( "PUT", @@ -404,14 +348,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_update_non_existent(self) -> None: + def test_update_non_existent(self): """Try to update a token that doesn't exist.""" channel = self.make_request( "PUT", @@ -420,14 +360,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_update_uses_allowed(self) -> None: + def test_update_uses_allowed(self): """Test updating just uses_allowed.""" # Create new token using default values token = self._new_token() @@ -439,7 +375,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertIsNone(channel.json_body["expiry_time"]) @@ -450,7 +386,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 0) self.assertIsNone(channel.json_body["expiry_time"]) @@ -461,7 +397,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": None}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -472,11 +408,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -486,14 +418,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_expiry_time(self) -> None: + def test_update_expiry_time(self): """Test updating just expiry_time.""" # Create new token using default values token = self._new_token() @@ -506,7 +434,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -517,7 +445,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": None}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -529,11 +457,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": past_time}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail a float @@ -543,14 +467,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time + 0.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_both(self) -> None: + def test_update_both(self): """Test updating both uses_allowed and expiry_time.""" # Create new token using default values token = self._new_token() @@ -568,11 +488,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) - def test_update_invalid_type(self) -> None: + def test_update_invalid_type(self): """Test using invalid types doesn't work.""" # Create new token using default values token = self._new_token() @@ -589,30 +509,22 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # DELETING - def test_delete_no_auth(self) -> None: + def test_delete_no_auth(self): """Try to delete a token without authentication.""" channel = self.make_request( "DELETE", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_delete_requester_not_admin(self) -> None: + def test_delete_requester_not_admin(self): """Try to delete a token while not an admin.""" channel = self.make_request( "DELETE", @@ -620,14 +532,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_delete_non_existent(self) -> None: + def test_delete_non_existent(self): """Try to delete a token that doesn't exist.""" channel = self.make_request( "DELETE", @@ -636,14 +544,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_delete(self) -> None: + def test_delete(self): """Test deleting a token.""" # Create new token using default values token = self._new_token() @@ -655,25 +559,21 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # GETTING ONE - def test_get_no_auth(self) -> None: + def test_get_no_auth(self): """Try to get a token without authentication.""" channel = self.make_request( "GET", self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_get_requester_not_admin(self) -> None: + def test_get_requester_not_admin(self): """Try to get a token while not an admin.""" channel = self.make_request( "GET", @@ -681,14 +581,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_get_non_existent(self) -> None: + def test_get_non_existent(self): """Try to get a token that doesn't exist.""" channel = self.make_request( "GET", @@ -697,14 +593,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_get(self) -> None: + def test_get(self): """Test getting a token.""" # Create new token using default values token = self._new_token() @@ -716,7 +608,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["token"], token) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -725,17 +617,13 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # LISTING - def test_list_no_auth(self) -> None: + def test_list_no_auth(self): """Try to list tokens without authentication.""" channel = self.make_request("GET", self.url, {}) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_list_requester_not_admin(self) -> None: + def test_list_requester_not_admin(self): """Try to list tokens while not an admin.""" channel = self.make_request( "GET", @@ -743,14 +631,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_list_all(self) -> None: + def test_list_all(self): """Test listing all tokens.""" # Create new token using default values token = self._new_token() @@ -762,7 +646,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["registration_tokens"]), 1) token_info = channel.json_body["registration_tokens"][0] self.assertEqual(token_info["token"], token) @@ -771,7 +655,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(token_info["pending"], 0) self.assertEqual(token_info["completed"], 0) - def test_list_invalid_query_parameter(self) -> None: + def test_list_invalid_query_parameter(self): """Test with `valid` query parameter not `true` or `false`.""" channel = self.make_request( "GET", @@ -780,13 +664,9 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - def _test_list_query_parameter(self, valid: str) -> None: + def _test_list_query_parameter(self, valid: str): """Helper used to test both valid=true and valid=false.""" # Create 2 valid and 2 invalid tokens. now = self.hs.get_clock().time_msec() @@ -816,17 +696,17 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(len(channel.json_body["registration_tokens"]), 2) token_info_1 = channel.json_body["registration_tokens"][0] token_info_2 = channel.json_body["registration_tokens"][1] self.assertIn(token_info_1["token"], tokens) self.assertIn(token_info_2["token"], tokens) - def test_list_valid(self) -> None: + def test_list_valid(self): """Test listing just valid tokens.""" self._test_list_query_parameter(valid="true") - def test_list_invalid(self) -> None: + def test_list_invalid(self): """Test listing just invalid tokens.""" self._test_list_query_parameter(valid="false") |