summary refs log tree commit diff
path: root/tests/rest/admin/test_user.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/admin/test_user.py')
-rw-r--r--tests/rest/admin/test_user.py408
1 files changed, 405 insertions, 3 deletions
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py

index 79a05b519b..a7b600a1d4 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -19,8 +19,7 @@ import json import urllib.parse from binascii import unhexlify from typing import List, Optional - -from mock import Mock +from unittest.mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes @@ -28,7 +27,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from tests import unittest from tests.server import FakeSite, make_request @@ -467,6 +466,8 @@ class UsersListTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users" def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -634,6 +635,26 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + # unkown order_by + channel = self.make_request( + "GET", + self.url + "?order_by=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid search order + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + def test_limit(self): """ Testing list of users with limit @@ -759,6 +780,103 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) + def test_order_by(self): + """ + Testing order list with parameter `order_by` + """ + + user1 = self.register_user("user1", "pass1", admin=False, displayname="Name Z") + user2 = self.register_user("user2", "pass2", admin=False, displayname="Name Y") + + # Modify user + self.get_success(self.store.set_user_deactivated_status(user1, True)) + self.get_success(self.store.set_shadow_banned(UserID.from_string(user1), True)) + + # Set avatar URL to all users, that no user has a NULL value to avoid + # different sort order between SQlite and PostreSQL + self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3")) + self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2")) + self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1")) + + # order by default (name) + self._order_test([self.admin_user, user1, user2], None) + self._order_test([self.admin_user, user1, user2], None, "f") + self._order_test([user2, user1, self.admin_user], None, "b") + + # order by name + self._order_test([self.admin_user, user1, user2], "name") + self._order_test([self.admin_user, user1, user2], "name", "f") + self._order_test([user2, user1, self.admin_user], "name", "b") + + # order by displayname + self._order_test([user2, user1, self.admin_user], "displayname") + self._order_test([user2, user1, self.admin_user], "displayname", "f") + self._order_test([self.admin_user, user1, user2], "displayname", "b") + + # order by is_guest + # like sort by ascending name, as no guest user here + self._order_test([self.admin_user, user1, user2], "is_guest") + self._order_test([self.admin_user, user1, user2], "is_guest", "f") + self._order_test([self.admin_user, user1, user2], "is_guest", "b") + + # order by admin + self._order_test([user1, user2, self.admin_user], "admin") + self._order_test([user1, user2, self.admin_user], "admin", "f") + self._order_test([self.admin_user, user1, user2], "admin", "b") + + # order by deactivated + self._order_test([self.admin_user, user2, user1], "deactivated") + self._order_test([self.admin_user, user2, user1], "deactivated", "f") + self._order_test([user1, self.admin_user, user2], "deactivated", "b") + + # order by user_type + # like sort by ascending name, as no special user type here + self._order_test([self.admin_user, user1, user2], "user_type") + self._order_test([self.admin_user, user1, user2], "user_type", "f") + self._order_test([self.admin_user, user1, user2], "is_guest", "b") + + # order by shadow_banned + self._order_test([self.admin_user, user2, user1], "shadow_banned") + self._order_test([self.admin_user, user2, user1], "shadow_banned", "f") + self._order_test([user1, self.admin_user, user2], "shadow_banned", "b") + + # order by avatar_url + self._order_test([self.admin_user, user2, user1], "avatar_url") + self._order_test([self.admin_user, user2, user1], "avatar_url", "f") + self._order_test([user1, user2, self.admin_user], "avatar_url", "b") + + def _order_test( + self, + expected_user_list: List[str], + order_by: Optional[str], + dir: Optional[str] = None, + ): + """Request the list of users in a certain order. Assert that order is what + we expect + Args: + expected_user_list: The list of user_id in the order we expect to get + back from the server + order_by: The type of ordering to give the server + dir: The direction of ordering to give the server + """ + + url = self.url + "?deactivated=true&" + if order_by is not None: + url += "order_by=%s&" % (order_by,) + if dir is not None and dir in ("b", "f"): + url += "dir=%s" % (dir,) + channel = self.make_request( + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], len(expected_user_list)) + + returned_order = [row["name"] for row in channel.json_body["users"]] + self.assertEqual(expected_user_list, returned_order) + self._check_fields(channel.json_body["users"]) + def _check_fields(self, content: JsonDict): """Checks that the expected user attributes are present in content Args: @@ -2908,3 +3026,287 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): # Ensure the user is shadow-banned (and the cache was cleared). result = self.get_success(self.store.get_user_by_access_token(other_user_token)) self.assertTrue(result.shadow_banned) + + +class RateLimitTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = ( + "/_synapse/admin/v1/users/%s/override_ratelimit" + % urllib.parse.quote(self.other_user) + ) + + def test_no_auth(self): + """ + Try to get information of a user without authentication. + """ + channel = self.make_request("GET", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + channel = self.make_request("POST", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + channel = self.make_request("DELETE", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + channel = self.make_request( + "GET", + self.url, + access_token=other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + channel = self.make_request( + "POST", + self.url, + access_token=other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + channel = self.make_request( + "DELETE", + self.url, + access_token=other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + channel = self.make_request( + "POST", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = ( + "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit" + ) + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + channel = self.make_request( + "POST", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual( + "Only local users can be ratelimited", channel.json_body["error"] + ) + + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual( + "Only local users can be ratelimited", channel.json_body["error"] + ) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + # messages_per_second is a string + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"messages_per_second": "string"}, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # messages_per_second is negative + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"messages_per_second": -1}, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # burst_count is a string + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"burst_count": "string"}, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # burst_count is negative + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"burst_count": -1}, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_return_zero_when_null(self): + """ + If values in database are `null` API should return an int `0` + """ + + self.get_success( + self.store.db_pool.simple_upsert( + table="ratelimit_override", + keyvalues={"user_id": self.other_user}, + values={ + "messages_per_second": None, + "burst_count": None, + }, + ) + ) + + # request status + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(0, channel.json_body["messages_per_second"]) + self.assertEqual(0, channel.json_body["burst_count"]) + + def test_success(self): + """ + Rate-limiting (set/update/delete) should succeed for an admin. + """ + # request status + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertNotIn("messages_per_second", channel.json_body) + self.assertNotIn("burst_count", channel.json_body) + + # set ratelimit + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"messages_per_second": 10, "burst_count": 11}, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(10, channel.json_body["messages_per_second"]) + self.assertEqual(11, channel.json_body["burst_count"]) + + # update ratelimit + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"messages_per_second": 20, "burst_count": 21}, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(20, channel.json_body["messages_per_second"]) + self.assertEqual(21, channel.json_body["burst_count"]) + + # request status + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(20, channel.json_body["messages_per_second"]) + self.assertEqual(21, channel.json_body["burst_count"]) + + # delete ratelimit + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertNotIn("messages_per_second", channel.json_body) + self.assertNotIn("burst_count", channel.json_body) + + # request status + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertNotIn("messages_per_second", channel.json_body) + self.assertNotIn("burst_count", channel.json_body)