diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index f09f66da00..5f73dbdc4a 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -328,6 +328,61 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
+ @override_config(
+ {
+ "user_types": {
+ "extra_user_types": ["extra1", "extra2"],
+ }
+ }
+ )
+ def test_extra_user_type(self) -> None:
+ """
+ Check that the extra user type can be used when registering a user.
+ """
+
+ def nonce_mac(user_type: str) -> tuple[str, str]:
+ """
+ Get a nonce and the expected HMAC for that nonce.
+ """
+ channel = self.make_request("GET", self.url)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(
+ nonce.encode("ascii")
+ + b"\x00alice\x00abc123\x00notadmin\x00"
+ + user_type.encode("ascii")
+ )
+ want_mac_str = want_mac.hexdigest()
+
+ return nonce, want_mac_str
+
+ nonce, mac = nonce_mac("extra1")
+ # Valid user_type
+ body = {
+ "nonce": nonce,
+ "username": "alice",
+ "password": "abc123",
+ "user_type": "extra1",
+ "mac": mac,
+ }
+ channel = self.make_request("POST", self.url, body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ nonce, mac = nonce_mac("extra3")
+ # Invalid user_type
+ body = {
+ "nonce": nonce,
+ "username": "alice",
+ "password": "abc123",
+ "user_type": "extra3",
+ "mac": mac,
+ }
+ channel = self.make_request("POST", self.url, body)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Invalid user type", channel.json_body["error"])
+
def test_displayname(self) -> None:
"""
Test that displayname of new user is set
@@ -1186,6 +1241,80 @@ class UsersListTestCase(unittest.HomeserverTestCase):
not_user_types=["custom"],
)
+ @override_config(
+ {
+ "user_types": {
+ "extra_user_types": ["extra1", "extra2"],
+ }
+ }
+ )
+ def test_filter_not_user_types_with_extra(self) -> None:
+ """Tests that the endpoint handles the not_user_types param when extra_user_types are configured"""
+
+ regular_user_id = self.register_user("normalo", "secret")
+
+ extra1_user_id = self.register_user("extra1", "secret")
+ self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/" + urllib.parse.quote(extra1_user_id),
+ {"user_type": "extra1"},
+ access_token=self.admin_user_tok,
+ )
+
+ def test_user_type(
+ expected_user_ids: List[str], not_user_types: Optional[List[str]] = None
+ ) -> None:
+ """Runs a test for the not_user_types param
+ Args:
+ expected_user_ids: Ids of the users that are expected to be returned
+ not_user_types: List of values for the not_user_types param
+ """
+
+ user_type_query = ""
+
+ if not_user_types is not None:
+ user_type_query = "&".join(
+ [f"not_user_type={u}" for u in not_user_types]
+ )
+
+ test_url = f"{self.url}?{user_type_query}"
+ channel = self.make_request(
+ "GET",
+ test_url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code)
+ self.assertEqual(channel.json_body["total"], len(expected_user_ids))
+ self.assertEqual(
+ expected_user_ids,
+ [u["name"] for u in channel.json_body["users"]],
+ )
+
+ # Request without user_types → all users expected
+ test_user_type([self.admin_user, extra1_user_id, regular_user_id])
+
+ # Request and exclude extra1 user type
+ test_user_type(
+ [self.admin_user, regular_user_id],
+ not_user_types=["extra1"],
+ )
+
+ # Request and exclude extra1 and extra2 user types
+ test_user_type(
+ [self.admin_user, regular_user_id],
+ not_user_types=["extra1", "extra2"],
+ )
+
+ # Request and exclude empty user types → only expected the extra1 user
+ test_user_type([extra1_user_id], not_user_types=[""])
+
+ # Request and exclude an unregistered type → expect all users
+ test_user_type(
+ [self.admin_user, extra1_user_id, regular_user_id],
+ not_user_types=["extra3"],
+ )
+
def test_erasure_status(self) -> None:
# Create a new user.
user_id = self.register_user("eraseme", "eraseme")
@@ -2977,22 +3106,18 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
- def test_set_user_type(self) -> None:
- """
- Test changing user type.
- """
-
- # Set to support type
+ def set_user_type(self, user_type: Optional[str]) -> None:
+ # Set to user_type
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content={"user_type": UserTypes.SUPPORT},
+ content={"user_type": user_type},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
+ self.assertEqual(user_type, channel.json_body["user_type"])
# Get user
channel = self.make_request(
@@ -3003,30 +3128,44 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
+ self.assertEqual(user_type, channel.json_body["user_type"])
+
+ def test_set_user_type(self) -> None:
+ """
+ Test changing user type.
+ """
+
+ # Set to support type
+ self.set_user_type(UserTypes.SUPPORT)
# Change back to a regular user
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"user_type": None},
- )
+ self.set_user_type(None)
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertIsNone(channel.json_body["user_type"])
+ @override_config({"user_types": {"extra_user_types": ["extra1", "extra2"]}})
+ def test_set_user_type_with_extras(self) -> None:
+ """
+ Test changing user type with extra_user_types configured.
+ """
- # Get user
+ # Check that we can still set to support type
+ self.set_user_type(UserTypes.SUPPORT)
+
+ # Check that we can set to an extra user type
+ self.set_user_type("extra2")
+
+ # Change back to a regular user
+ self.set_user_type(None)
+
+ # Try setting to invalid type
channel = self.make_request(
- "GET",
+ "PUT",
self.url_other_user,
access_token=self.admin_user_tok,
+ content={"user_type": "extra3"},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertIsNone(channel.json_body["user_type"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Invalid user type", channel.json_body["error"])
def test_accidental_deactivation_prevention(self) -> None:
"""
|