summary refs log tree commit diff
path: root/tests/rest/admin/test_user.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/admin/test_user.py')
-rw-r--r--tests/rest/admin/test_user.py101
1 files changed, 98 insertions, 3 deletions
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index df62317e69..4f379a5e55 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -18,6 +18,7 @@ import hmac
 import json
 import urllib.parse
 from binascii import unhexlify
+from typing import Optional
 
 from mock import Mock
 
@@ -466,8 +467,12 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
-        self.register_user("user1", "pass1", admin=False)
-        self.register_user("user2", "pass2", admin=False)
+        self.user1 = self.register_user(
+            "user1", "pass1", admin=False, displayname="Name 1"
+        )
+        self.user2 = self.register_user(
+            "user2", "pass2", admin=False, displayname="Name 2"
+        )
 
     def test_no_auth(self):
         """
@@ -476,7 +481,20 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        other_user_token = self.login("user1", "pass1")
+
+        request, channel = self.make_request(
+            "GET", self.url, access_token=other_user_token,
+        )
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_all_users(self):
         """
@@ -493,6 +511,83 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(3, len(channel.json_body["users"]))
         self.assertEqual(3, channel.json_body["total"])
 
+        # Check that all fields are available
+        for u in channel.json_body["users"]:
+            self.assertIn("name", u)
+            self.assertIn("is_guest", u)
+            self.assertIn("admin", u)
+            self.assertIn("user_type", u)
+            self.assertIn("deactivated", u)
+            self.assertIn("displayname", u)
+            self.assertIn("avatar_url", u)
+
+    def test_search_term(self):
+        """Test that searching for a users works correctly"""
+
+        def _search_test(
+            expected_user_id: Optional[str],
+            search_term: str,
+            search_field: Optional[str] = "name",
+            expected_http_code: Optional[int] = 200,
+        ):
+            """Search for a user and check that the returned user's id is a match
+
+            Args:
+                expected_user_id: The user_id expected to be returned by the API. Set
+                    to None to expect zero results for the search
+                search_term: The term to search for user names with
+                search_field: Field which is to request: `name` or `user_id`
+                expected_http_code: The expected http code for the request
+            """
+            url = self.url + "?%s=%s" % (search_field, search_term,)
+            request, channel = self.make_request(
+                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            )
+            self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+            if expected_http_code != 200:
+                return
+
+            # Check that users were returned
+            self.assertTrue("users" in channel.json_body)
+            users = channel.json_body["users"]
+
+            # Check that the expected number of users were returned
+            expected_user_count = 1 if expected_user_id else 0
+            self.assertEqual(len(users), expected_user_count)
+            self.assertEqual(channel.json_body["total"], expected_user_count)
+
+            if expected_user_id:
+                # Check that the first returned user id is correct
+                u = users[0]
+                self.assertEqual(expected_user_id, u["name"])
+
+        # Perform search tests
+        _search_test(self.user1, "er1")
+        _search_test(self.user1, "me 1")
+
+        _search_test(self.user2, "er2")
+        _search_test(self.user2, "me 2")
+
+        _search_test(self.user1, "er1", "user_id")
+        _search_test(self.user2, "er2", "user_id")
+
+        # Test case insensitive
+        _search_test(self.user1, "ER1")
+        _search_test(self.user1, "NAME 1")
+
+        _search_test(self.user2, "ER2")
+        _search_test(self.user2, "NAME 2")
+
+        _search_test(self.user1, "ER1", "user_id")
+        _search_test(self.user2, "ER2", "user_id")
+
+        _search_test(None, "foo")
+        _search_test(None, "bar")
+
+        _search_test(None, "foo", "user_id")
+        _search_test(None, "bar", "user_id")
+
 
 class UserRestTestCase(unittest.HomeserverTestCase):