summary refs log tree commit diff
path: root/tests/rest/admin
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/admin')
-rw-r--r--tests/rest/admin/test_room.py15
-rw-r--r--tests/rest/admin/test_user.py287
2 files changed, 275 insertions, 27 deletions
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index a0f32c5512..7c47aa7e0a 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1180,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.json_body["total"], 3)
 
+    def test_room_state(self):
+        """Test that room state can be requested correctly"""
+        # Create two test rooms
+        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("state", channel.json_body)
+        # testing that the state events match is painful and not done here. We assume that
+        # the create_room already does the right thing, so no need to verify that we got
+        # the state events it created.
+
 
 class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 04599c2fcf..ee05ee60bc 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
 from synapse.api.room_versions import RoomVersions
 from synapse.rest.client.v1 import login, logout, profile, room
 from synapse.rest.client.v2_alpha import devices, sync
+from synapse.types import JsonDict
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
-        self.user1 = self.register_user(
-            "user1", "pass1", admin=False, displayname="Name 1"
-        )
-        self.user2 = self.register_user(
-            "user2", "pass2", admin=False, displayname="Name 2"
-        )
-
     def test_no_auth(self):
         """
         Try to list users without authentication.
@@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         """
         If the user is not a server admin, an error is returned.
         """
+        self._create_users(1)
         other_user_token = self.login("user1", "pass1")
 
         channel = self.make_request("GET", self.url, access_token=other_user_token)
@@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         """
         List all users, including deactivated users.
         """
+        self._create_users(2)
+
         channel = self.make_request(
             "GET",
             self.url + "?deactivated=true",
@@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(3, channel.json_body["total"])
 
         # Check that all fields are available
-        for u in channel.json_body["users"]:
-            self.assertIn("name", u)
-            self.assertIn("is_guest", u)
-            self.assertIn("admin", u)
-            self.assertIn("user_type", u)
-            self.assertIn("deactivated", u)
-            self.assertIn("displayname", u)
-            self.assertIn("avatar_url", u)
+        self._check_fields(channel.json_body["users"])
 
     def test_search_term(self):
         """Test that searching for a users works correctly"""
@@ -549,6 +539,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
 
             # Check that users were returned
             self.assertTrue("users" in channel.json_body)
+            self._check_fields(channel.json_body["users"])
             users = channel.json_body["users"]
 
             # Check that the expected number of users were returned
@@ -561,25 +552,30 @@ class UsersListTestCase(unittest.HomeserverTestCase):
                 u = users[0]
                 self.assertEqual(expected_user_id, u["name"])
 
+        self._create_users(2)
+
+        user1 = "@user1:test"
+        user2 = "@user2:test"
+
         # Perform search tests
-        _search_test(self.user1, "er1")
-        _search_test(self.user1, "me 1")
+        _search_test(user1, "er1")
+        _search_test(user1, "me 1")
 
-        _search_test(self.user2, "er2")
-        _search_test(self.user2, "me 2")
+        _search_test(user2, "er2")
+        _search_test(user2, "me 2")
 
-        _search_test(self.user1, "er1", "user_id")
-        _search_test(self.user2, "er2", "user_id")
+        _search_test(user1, "er1", "user_id")
+        _search_test(user2, "er2", "user_id")
 
         # Test case insensitive
-        _search_test(self.user1, "ER1")
-        _search_test(self.user1, "NAME 1")
+        _search_test(user1, "ER1")
+        _search_test(user1, "NAME 1")
 
-        _search_test(self.user2, "ER2")
-        _search_test(self.user2, "NAME 2")
+        _search_test(user2, "ER2")
+        _search_test(user2, "NAME 2")
 
-        _search_test(self.user1, "ER1", "user_id")
-        _search_test(self.user2, "ER2", "user_id")
+        _search_test(user1, "ER1", "user_id")
+        _search_test(user2, "ER2", "user_id")
 
         _search_test(None, "foo")
         _search_test(None, "bar")
@@ -587,6 +583,179 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         _search_test(None, "foo", "user_id")
         _search_test(None, "bar", "user_id")
 
+    def test_invalid_parameter(self):
+        """
+        If parameters are invalid, an error is returned.
+        """
+
+        # negative limit
+        channel = self.make_request(
+            "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # negative from
+        channel = self.make_request(
+            "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # invalid guests
+        channel = self.make_request(
+            "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+        # invalid deactivated
+        channel = self.make_request(
+            "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+    def test_limit(self):
+        """
+        Testing list of users with limit
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 5)
+        self.assertEqual(channel.json_body["next_token"], "5")
+        self._check_fields(channel.json_body["users"])
+
+    def test_from(self):
+        """
+        Testing list of users with a defined starting point (from)
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 15)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["users"])
+
+    def test_limit_and_from(self):
+        """
+        Testing list of users with a defined starting point and limit
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(channel.json_body["next_token"], "15")
+        self.assertEqual(len(channel.json_body["users"]), 10)
+        self._check_fields(channel.json_body["users"])
+
+    def test_next_token(self):
+        """
+        Testing that `next_token` appears at the right place
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        #  `next_token` does not appear
+        # Number of results is the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), number_users)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does not appear
+        # Number of max results is larger than the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), number_users)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does appear
+        # Number of max results is smaller than the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 19)
+        self.assertEqual(channel.json_body["next_token"], "19")
+
+        # Check
+        # Set `from` to value of `next_token` for request remaining entries
+        #  `next_token` does not appear
+        channel = self.make_request(
+            "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 1)
+        self.assertNotIn("next_token", channel.json_body)
+
+    def _check_fields(self, content: JsonDict):
+        """Checks that the expected user attributes are present in content
+        Args:
+            content: List that is checked for content
+        """
+        for u in content:
+            self.assertIn("name", u)
+            self.assertIn("is_guest", u)
+            self.assertIn("admin", u)
+            self.assertIn("user_type", u)
+            self.assertIn("deactivated", u)
+            self.assertIn("displayname", u)
+            self.assertIn("avatar_url", u)
+
+    def _create_users(self, number_users: int):
+        """
+        Create a number of users
+        Args:
+            number_users: Number of users to be created
+        """
+        for i in range(1, number_users + 1):
+            self.register_user(
+                "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
+            )
+
 
 class DeactivateAccountTestCase(unittest.HomeserverTestCase):
 
@@ -2211,3 +2380,67 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
+
+
+class ShadowBanRestTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+
+        self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote(
+            self.other_user
+        )
+
+    def test_no_auth(self):
+        """
+        Try to get information of an user without authentication.
+        """
+        channel = self.make_request("POST", self.url)
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_not_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        other_user_token = self.login("user", "pass")
+
+        channel = self.make_request("POST", self.url, access_token=other_user_token)
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that shadow-banning for a user that is not a local returns a 400
+        """
+        url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+
+    def test_success(self):
+        """
+        Shadow-banning should succeed for an admin.
+        """
+        # The user starts off as not shadow-banned.
+        other_user_token = self.login("user", "pass")
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertFalse(result.shadow_banned)
+
+        channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual({}, channel.json_body)
+
+        # Ensure the user is shadow-banned (and the cache was cleared).
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertTrue(result.shadow_banned)