summary refs log tree commit diff
path: root/tests/rest/admin/test_user.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/admin/test_user.py')
-rw-r--r--tests/rest/admin/test_user.py185
1 files changed, 162 insertions, 23 deletions
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py

index f09f66da00..5f73dbdc4a 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -328,6 +328,61 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) + @override_config( + { + "user_types": { + "extra_user_types": ["extra1", "extra2"], + } + } + ) + def test_extra_user_type(self) -> None: + """ + Check that the extra user type can be used when registering a user. + """ + + def nonce_mac(user_type: str) -> tuple[str, str]: + """ + Get a nonce and the expected HMAC for that nonce. + """ + channel = self.make_request("GET", self.url) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update( + nonce.encode("ascii") + + b"\x00alice\x00abc123\x00notadmin\x00" + + user_type.encode("ascii") + ) + want_mac_str = want_mac.hexdigest() + + return nonce, want_mac_str + + nonce, mac = nonce_mac("extra1") + # Valid user_type + body = { + "nonce": nonce, + "username": "alice", + "password": "abc123", + "user_type": "extra1", + "mac": mac, + } + channel = self.make_request("POST", self.url, body) + self.assertEqual(200, channel.code, msg=channel.json_body) + + nonce, mac = nonce_mac("extra3") + # Invalid user_type + body = { + "nonce": nonce, + "username": "alice", + "password": "abc123", + "user_type": "extra3", + "mac": mac, + } + channel = self.make_request("POST", self.url, body) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Invalid user type", channel.json_body["error"]) + def test_displayname(self) -> None: """ Test that displayname of new user is set @@ -1186,6 +1241,80 @@ class UsersListTestCase(unittest.HomeserverTestCase): not_user_types=["custom"], ) + @override_config( + { + "user_types": { + "extra_user_types": ["extra1", "extra2"], + } + } + ) + def test_filter_not_user_types_with_extra(self) -> None: + """Tests that the endpoint handles the not_user_types param when extra_user_types are configured""" + + regular_user_id = self.register_user("normalo", "secret") + + extra1_user_id = self.register_user("extra1", "secret") + self.make_request( + "PUT", + "/_synapse/admin/v2/users/" + urllib.parse.quote(extra1_user_id), + {"user_type": "extra1"}, + access_token=self.admin_user_tok, + ) + + def test_user_type( + expected_user_ids: List[str], not_user_types: Optional[List[str]] = None + ) -> None: + """Runs a test for the not_user_types param + Args: + expected_user_ids: Ids of the users that are expected to be returned + not_user_types: List of values for the not_user_types param + """ + + user_type_query = "" + + if not_user_types is not None: + user_type_query = "&".join( + [f"not_user_type={u}" for u in not_user_types] + ) + + test_url = f"{self.url}?{user_type_query}" + channel = self.make_request( + "GET", + test_url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code) + self.assertEqual(channel.json_body["total"], len(expected_user_ids)) + self.assertEqual( + expected_user_ids, + [u["name"] for u in channel.json_body["users"]], + ) + + # Request without user_types → all users expected + test_user_type([self.admin_user, extra1_user_id, regular_user_id]) + + # Request and exclude extra1 user type + test_user_type( + [self.admin_user, regular_user_id], + not_user_types=["extra1"], + ) + + # Request and exclude extra1 and extra2 user types + test_user_type( + [self.admin_user, regular_user_id], + not_user_types=["extra1", "extra2"], + ) + + # Request and exclude empty user types → only expected the extra1 user + test_user_type([extra1_user_id], not_user_types=[""]) + + # Request and exclude an unregistered type → expect all users + test_user_type( + [self.admin_user, extra1_user_id, regular_user_id], + not_user_types=["extra3"], + ) + def test_erasure_status(self) -> None: # Create a new user. user_id = self.register_user("eraseme", "eraseme") @@ -2977,22 +3106,18 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) - def test_set_user_type(self) -> None: - """ - Test changing user type. - """ - - # Set to support type + def set_user_type(self, user_type: Optional[str]) -> None: + # Set to user_type channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content={"user_type": UserTypes.SUPPORT}, + content={"user_type": user_type}, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) + self.assertEqual(user_type, channel.json_body["user_type"]) # Get user channel = self.make_request( @@ -3003,30 +3128,44 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) + self.assertEqual(user_type, channel.json_body["user_type"]) + + def test_set_user_type(self) -> None: + """ + Test changing user type. + """ + + # Set to support type + self.set_user_type(UserTypes.SUPPORT) # Change back to a regular user - channel = self.make_request( - "PUT", - self.url_other_user, - access_token=self.admin_user_tok, - content={"user_type": None}, - ) + self.set_user_type(None) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertIsNone(channel.json_body["user_type"]) + @override_config({"user_types": {"extra_user_types": ["extra1", "extra2"]}}) + def test_set_user_type_with_extras(self) -> None: + """ + Test changing user type with extra_user_types configured. + """ - # Get user + # Check that we can still set to support type + self.set_user_type(UserTypes.SUPPORT) + + # Check that we can set to an extra user type + self.set_user_type("extra2") + + # Change back to a regular user + self.set_user_type(None) + + # Try setting to invalid type channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"user_type": "extra3"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertIsNone(channel.json_body["user_type"]) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Invalid user type", channel.json_body["error"]) def test_accidental_deactivation_prevention(self) -> None: """