From 35b1900f00b77e754efb909eae0a2f0c94e968cb Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 30 Nov 2021 10:53:54 +0100 Subject: Convert status codes to `HTTPStatus` in `tests.rest.admin` (#11455) --- tests/rest/admin/test_statistics.py | 102 +++++++++++++++++++++++++----------- 1 file changed, 71 insertions(+), 31 deletions(-) (limited to 'tests/rest/admin/test_statistics.py') diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index ece89a65ac..43d8ca032b 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json +from http import HTTPStatus from typing import Any, Dict, List, Optional import synapse.rest.admin @@ -47,21 +47,29 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ channel = self.make_request( "GET", self.url, - json.dumps({}), + {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -75,7 +83,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -85,7 +97,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit @@ -95,7 +111,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from_ts @@ -105,7 +125,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative until_ts @@ -115,7 +139,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # until_ts smaller from_ts @@ -125,7 +153,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # empty search term @@ -135,7 +167,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -145,7 +181,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_limit(self): @@ -160,7 +200,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -178,7 +218,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -196,7 +236,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["users"]), 10) @@ -218,7 +258,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -231,7 +271,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -244,7 +284,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -257,7 +297,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -274,7 +314,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["users"])) @@ -371,7 +411,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media starting at `ts1` after creating first media @@ -381,7 +421,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) self._create_media(self.other_user_tok, 3) @@ -396,7 +436,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media until `ts2` and earlier @@ -405,7 +445,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) def test_search_term(self): @@ -417,7 +457,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) # filter user 1 and 10-19 by `user_id` @@ -426,7 +466,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foo_user_1", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 11) # filter on this user in `displayname` @@ -435,7 +475,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=bar_user_10", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10") self.assertEqual(channel.json_body["total"], 1) @@ -445,7 +485,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foobar", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) def _create_users_with_media(self, number_users: int, media_per_user: int): @@ -471,7 +511,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): for _ in range(number_media): # Upload some media into the room self.helper.upload_media( - upload_resource, SMALL_PNG, tok=user_token, expect_code=200 + upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK ) def _check_fields(self, content: List[Dict[str, Any]]): @@ -505,7 +545,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["user_id"] for row in channel.json_body["users"]] -- cgit 1.5.1 From e5f426cd54609e7f05f8241d845e6e36c5f10d9a Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 3 Dec 2021 14:57:13 +0100 Subject: Add type hints to `synapse/tests/rest/admin` (#11501) --- changelog.d/11501.misc | 1 + mypy.ini | 3 - tests/rest/admin/test_background_updates.py | 21 +++--- tests/rest/admin/test_device.py | 55 ++++++++-------- tests/rest/admin/test_event_reports.py | 61 ++++++++++-------- tests/rest/admin/test_media.py | 69 ++++++++++---------- tests/rest/admin/test_registration_tokens.py | 77 +++++++++++----------- tests/rest/admin/test_room.py | 95 +++++++++++++++------------- tests/rest/admin/test_server_notice.py | 27 ++++---- tests/rest/admin/test_statistics.py | 40 ++++++------ tests/rest/admin/test_user.py | 36 +++++------ 11 files changed, 257 insertions(+), 228 deletions(-) create mode 100644 changelog.d/11501.misc (limited to 'tests/rest/admin/test_statistics.py') diff --git a/changelog.d/11501.misc b/changelog.d/11501.misc new file mode 100644 index 0000000000..40e01194df --- /dev/null +++ b/changelog.d/11501.misc @@ -0,0 +1 @@ +Add type hints to `synapse/tests/rest/admin`. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index fea71d154c..1caf807e85 100644 --- a/mypy.ini +++ b/mypy.ini @@ -86,9 +86,6 @@ exclude = (?x) |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/admin/test_admin.py - |tests/rest/admin/test_device.py - |tests/rest/admin/test_media.py - |tests/rest/admin/test_server_notice.py |tests/rest/admin/test_user.py |tests/rest/admin/test_username_available.py |tests/rest/client/test_account.py diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index a5423af652..4d152c0d66 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -16,11 +16,14 @@ from typing import Collection from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login from synapse.server import HomeServer from synapse.storage.background_updates import BackgroundUpdater +from synapse.util import Clock from tests import unittest @@ -31,7 +34,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -44,7 +47,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ("POST", "/_synapse/admin/v1/background_updates/start_job"), ] ) - def test_requester_is_no_admin(self, method: str, url: str): + def test_requester_is_no_admin(self, method: str, url: str) -> None: """ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ @@ -62,7 +65,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ @@ -90,7 +93,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - def _register_bg_update(self): + def _register_bg_update(self) -> None: "Adds a bg update but doesn't start it" async def _fake_update(progress, batch_size) -> int: @@ -112,7 +115,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ) ) - def test_status_empty(self): + def test_status_empty(self) -> None: """Test the status API works.""" channel = self.make_request( @@ -127,7 +130,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): channel.json_body, {"current_updates": {}, "enabled": True} ) - def test_status_bg_update(self): + def test_status_bg_update(self) -> None: """Test the status API works with a background update.""" # Create a new background update @@ -162,7 +165,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): }, ) - def test_enabled(self): + def test_enabled(self) -> None: """Test the enabled API works.""" # Create a new background update @@ -299,7 +302,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ), ] ) - def test_start_backround_job(self, job_name: str, updates: Collection[str]): + def test_start_backround_job(self, job_name: str, updates: Collection[str]) -> None: """ Test that background updates add to database and be processed. @@ -341,7 +344,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ) ) - def test_start_backround_job_twice(self): + def test_start_backround_job_twice(self) -> None: """Test that add a background update twice return an error.""" # add job to database diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index baff057c56..f7080bda87 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import urllib.parse from http import HTTPStatus from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -31,7 +34,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_device_handler() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -48,7 +51,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_no_auth(self, method: str): + def test_no_auth(self, method: str) -> None: """ Try to get a device of an user without authentication. """ @@ -62,7 +65,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_requester_is_no_admin(self, method: str): + def test_requester_is_no_admin(self, method: str) -> None: """ If the user is not a server admin, an error is returned. """ @@ -80,7 +83,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_user_does_not_exist(self, method: str): + def test_user_does_not_exist(self, method: str) -> None: """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -99,7 +102,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_user_is_not_local(self, method: str): + def test_user_is_not_local(self, method: str) -> None: """ Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -117,7 +120,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_unknown_device(self): + def test_unknown_device(self) -> None: """ Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK. """ @@ -151,7 +154,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): # Delete unknown device returns status HTTPStatus.OK self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def test_update_device_too_long_display_name(self): + def test_update_device_too_long_display_name(self) -> None: """ Update a device with a display name that is invalid (too long). """ @@ -189,7 +192,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) - def test_update_no_display_name(self): + def test_update_no_display_name(self) -> None: """ Tests that a update for a device without JSON returns a HTTPStatus.OK """ @@ -219,7 +222,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) - def test_update_display_name(self): + def test_update_display_name(self) -> None: """ Tests a normal successful update of display name """ @@ -243,7 +246,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) - def test_get_device(self): + def test_get_device(self) -> None: """ Tests that a normal lookup for a device is successfully """ @@ -262,7 +265,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertIn("last_seen_ip", channel.json_body) self.assertIn("last_seen_ts", channel.json_body) - def test_delete_device(self): + def test_delete_device(self) -> None: """ Tests that a remove of a device is successfully """ @@ -292,7 +295,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -302,7 +305,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to list devices of an user without authentication. """ @@ -315,7 +318,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -334,7 +337,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -348,7 +351,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -363,7 +366,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_user_has_no_devices(self): + def test_user_has_no_devices(self) -> None: """ Tests that a normal lookup for devices is successfully if user has no devices @@ -380,7 +383,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["devices"])) - def test_get_devices(self): + def test_get_devices(self) -> None: """ Tests that a normal lookup for devices is successfully """ @@ -416,7 +419,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_device_handler() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -428,7 +431,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete devices of an user without authentication. """ @@ -441,7 +444,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -460,7 +463,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -474,7 +477,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -489,7 +492,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_unknown_devices(self): + def test_unknown_devices(self) -> None: """ Tests that a remove of a device that does not exist returns HTTPStatus.OK. """ @@ -503,7 +506,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): # Delete unknown devices returns status HTTPStatus.OK self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def test_delete_devices(self): + def test_delete_devices(self) -> None: """ Tests that a remove of devices is successfully """ diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index a9c46ec62d..4f89f8b534 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -11,12 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from http import HTTPStatus +from typing import List + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, report_event, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest @@ -29,7 +34,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -70,7 +75,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/event_reports" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to get an event report without authentication. """ @@ -83,7 +88,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ @@ -101,7 +106,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_default_success(self): + def test_default_success(self) -> None: """ Testing list of reported events """ @@ -118,7 +123,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["event_reports"]) - def test_limit(self): + def test_limit(self) -> None: """ Testing list of reported events with limit """ @@ -135,7 +140,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["event_reports"]) - def test_from(self): + def test_from(self) -> None: """ Testing list of reported events with a defined starting point (from) """ @@ -152,7 +157,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["event_reports"]) - def test_limit_and_from(self): + def test_limit_and_from(self) -> None: """ Testing list of reported events with a defined starting point and limit """ @@ -169,7 +174,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["event_reports"]), 10) self._check_fields(channel.json_body["event_reports"]) - def test_filter_room(self): + def test_filter_room(self) -> None: """ Testing list of reported events with a filter of room """ @@ -189,7 +194,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): for report in channel.json_body["event_reports"]: self.assertEqual(report["room_id"], self.room_id1) - def test_filter_user(self): + def test_filter_user(self) -> None: """ Testing list of reported events with a filter of user """ @@ -209,7 +214,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): for report in channel.json_body["event_reports"]: self.assertEqual(report["user_id"], self.other_user) - def test_filter_user_and_room(self): + def test_filter_user_and_room(self) -> None: """ Testing list of reported events with a filter of user and room """ @@ -230,7 +235,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(report["user_id"], self.other_user) self.assertEqual(report["room_id"], self.room_id1) - def test_valid_search_order(self): + def test_valid_search_order(self) -> None: """ Testing search order. Order by timestamps. """ @@ -271,7 +276,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) report += 1 - def test_invalid_search_order(self): + def test_invalid_search_order(self) -> None: """ Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST """ @@ -290,7 +295,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual("Unknown direction: bar", channel.json_body["error"]) - def test_limit_is_negative(self): + def test_limit_is_negative(self) -> None: """ Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST """ @@ -308,7 +313,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_from_is_negative(self): + def test_from_is_negative(self) -> None: """ Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST """ @@ -326,7 +331,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_next_token(self): + def test_next_token(self) -> None: """ Testing that `next_token` appears at the right place """ @@ -384,7 +389,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["event_reports"]), 1) self.assertNotIn("next_token", channel.json_body) - def _create_event_and_report(self, room_id, user_tok): + def _create_event_and_report(self, room_id: str, user_tok: str) -> None: """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -397,7 +402,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def _create_event_and_report_without_parameters(self, room_id, user_tok): + def _create_event_and_report_without_parameters( + self, room_id: str, user_tok: str + ) -> None: """Create and report an event, but omit reason and score""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -410,7 +417,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def _check_fields(self, content): + def _check_fields(self, content: List[JsonDict]) -> None: """Checks that all attributes are present in an event report""" for c in content: self.assertIn("id", c) @@ -433,7 +440,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -453,7 +460,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): # first created event report gets `id`=2 self.url = "/_synapse/admin/v1/event_reports/2" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to get event report without authentication. """ @@ -466,7 +473,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ @@ -484,7 +491,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_default_success(self): + def test_default_success(self) -> None: """ Testing get a reported event """ @@ -498,7 +505,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) - def test_invalid_report_id(self): + def test_invalid_report_id(self) -> None: """ Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST. """ @@ -557,7 +564,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_report_id_not_found(self): + def test_report_id_not_found(self) -> None: """ Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND. """ @@ -576,7 +583,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual("Event report not found", channel.json_body["error"]) - def _create_event_and_report(self, room_id, user_tok): + def _create_event_and_report(self, room_id: str, user_tok: str) -> None: """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -589,7 +596,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def _check_fields(self, content): + def _check_fields(self, content: JsonDict) -> None: """Checks that all attributes are present in a event report""" self.assertIn("id", content) self.assertIn("received_ts", content) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 6618279dd1..81e578fd26 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -12,16 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os from http import HTTPStatus from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, profile, room from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -39,7 +42,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -48,7 +51,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.filepaths = MediaFilePaths(hs.config.media.media_store_path) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete media without authentication. """ @@ -63,7 +66,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -85,7 +88,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_media_does_not_exist(self): + def test_media_does_not_exist(self) -> None: """ Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -100,7 +103,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_media_is_not_local(self): + def test_media_is_not_local(self) -> None: """ Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -115,7 +118,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) - def test_delete_media(self): + def test_delete_media(self) -> None: """ Tests that delete a media is successfully """ @@ -208,7 +211,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -221,7 +224,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # Move clock up to somewhat realistic time self.reactor.advance(1000000000) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete media without authentication. """ @@ -235,7 +238,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -255,7 +258,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_media_is_not_local(self): + def test_media_is_not_local(self) -> None: """ Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST """ @@ -270,7 +273,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) - def test_missing_parameter(self): + def test_missing_parameter(self) -> None: """ If the parameter `before_ts` is missing, an error is returned. """ @@ -290,7 +293,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): "Missing integer query parameter 'before_ts'", channel.json_body["error"] ) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ @@ -363,7 +366,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_delete_media_never_accessed(self): + def test_delete_media_never_accessed(self) -> None: """ Tests that media deleted if it is older than `before_ts` and never accessed `last_access_ts` is `NULL` and `created_ts` < `before_ts` @@ -394,7 +397,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_date(self): + def test_keep_media_by_date(self) -> None: """ Tests that media is not deleted if it is newer than `before_ts` """ @@ -431,7 +434,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_size(self): + def test_keep_media_by_size(self) -> None: """ Tests that media is not deleted if its size is smaller than or equal to `size_gt` @@ -466,7 +469,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_user_avatar(self): + def test_keep_media_by_user_avatar(self) -> None: """ Tests that we do not delete media if is used as a user avatar Tests parameter `keep_profiles` @@ -510,7 +513,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_room_avatar(self): + def test_keep_media_by_room_avatar(self) -> None: """ Tests that we do not delete media if it is used as a room avatar Tests parameter `keep_profiles` @@ -555,7 +558,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def _create_media(self): + def _create_media(self) -> str: """ Create a media and return media_id and server_and_media_id """ @@ -577,7 +580,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): return server_and_media_id - def _access_media(self, server_and_media_id, expect_success=True): + def _access_media(self, server_and_media_id, expect_success=True) -> None: """ Try to access a media and check the result """ @@ -627,7 +630,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() self.store = hs.get_datastore() self.server_name = hs.hostname @@ -652,7 +655,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/media/%s/%s/%s" @parameterized.expand(["quarantine", "unquarantine"]) - def test_no_auth(self, action: str): + def test_no_auth(self, action: str) -> None: """ Try to protect media without authentication. """ @@ -671,7 +674,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["quarantine", "unquarantine"]) - def test_requester_is_no_admin(self, action: str): + def test_requester_is_no_admin(self, action: str) -> None: """ If the user is not a server admin, an error is returned. """ @@ -691,7 +694,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_quarantine_media(self): + def test_quarantine_media(self) -> None: """ Tests that quarantining and remove from quarantine a media is successfully """ @@ -725,7 +728,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): media_info = self.get_success(self.store.get_local_media(self.media_id)) self.assertFalse(media_info["quarantined_by"]) - def test_quarantine_protected_media(self): + def test_quarantine_protected_media(self) -> None: """ Tests that quarantining from protected media fails """ @@ -760,7 +763,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() self.store = hs.get_datastore() @@ -784,7 +787,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/media/%s/%s" @parameterized.expand(["protect", "unprotect"]) - def test_no_auth(self, action: str): + def test_no_auth(self, action: str) -> None: """ Try to protect media without authentication. """ @@ -799,7 +802,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["protect", "unprotect"]) - def test_requester_is_no_admin(self, action: str): + def test_requester_is_no_admin(self, action: str) -> None: """ If the user is not a server admin, an error is returned. """ @@ -819,7 +822,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_protect_media(self): + def test_protect_media(self) -> None: """ Tests that protect and unprotect a media is successfully """ @@ -864,7 +867,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -874,7 +877,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.url = "/_synapse/admin/v1/purge_media_cache" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete media without authentication. """ @@ -888,7 +891,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_not_admin(self): + def test_requester_is_not_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -908,7 +911,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 63087955f2..350a62dda6 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import random import string from http import HTTPStatus +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -29,7 +32,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -39,7 +42,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/registration_tokens" - def _new_token(self, **kwargs): + def _new_token(self, **kwargs) -> str: """Helper function to create a token.""" token = kwargs.get( "token", @@ -61,7 +64,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # CREATION - def test_create_no_auth(self): + def test_create_no_auth(self) -> None: """Try to create a token without authentication.""" channel = self.make_request("POST", self.url + "/new", {}) self.assertEqual( @@ -71,7 +74,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_create_requester_not_admin(self): + def test_create_requester_not_admin(self) -> None: """Try to create a token while not an admin.""" channel = self.make_request( "POST", @@ -86,7 +89,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_create_using_defaults(self): + def test_create_using_defaults(self) -> None: """Create a token using all the defaults.""" channel = self.make_request( "POST", @@ -102,7 +105,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_specifying_fields(self): + def test_create_specifying_fields(self) -> None: """Create a token specifying the value of all fields.""" # As many of the allowed characters as possible with length <= 64 token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-" @@ -126,7 +129,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_with_null_value(self): + def test_create_with_null_value(self) -> None: """Create a token specifying unlimited uses and no expiry.""" data = { "uses_allowed": None, @@ -147,7 +150,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_token_too_long(self): + def test_create_token_too_long(self) -> None: """Check token longer than 64 chars is invalid.""" data = {"token": "a" * 65} @@ -165,7 +168,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_invalid_chars(self): + def test_create_token_invalid_chars(self) -> None: """Check you can't create token with invalid characters.""" data = { "token": "abc/def", @@ -185,7 +188,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_already_exists(self): + def test_create_token_already_exists(self) -> None: """Check you can't create token that already exists.""" data = { "token": "abcd", @@ -208,7 +211,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body) self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_unable_to_generate_token(self): + def test_create_unable_to_generate_token(self) -> None: """Check right error is raised when server can't generate unique token.""" # Create all possible single character tokens tokens = [] @@ -239,7 +242,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(500, channel.code, msg=channel.json_body) - def test_create_uses_allowed(self): + def test_create_uses_allowed(self) -> None: """Check you can only create a token with good values for uses_allowed.""" # Should work with 0 (token is invalid from the start) channel = self.make_request( @@ -279,7 +282,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_expiry_time(self): + def test_create_expiry_time(self) -> None: """Check you can't create a token with an invalid expiry_time.""" # Should fail with a time in the past channel = self.make_request( @@ -309,7 +312,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_length(self): + def test_create_length(self) -> None: """Check you can only generate a token with a valid length.""" # Should work with 64 channel = self.make_request( @@ -379,7 +382,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # UPDATING - def test_update_no_auth(self): + def test_update_no_auth(self) -> None: """Try to update a token without authentication.""" channel = self.make_request( "PUT", @@ -393,7 +396,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_update_requester_not_admin(self): + def test_update_requester_not_admin(self) -> None: """Try to update a token while not an admin.""" channel = self.make_request( "PUT", @@ -408,7 +411,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_update_non_existent(self): + def test_update_non_existent(self) -> None: """Try to update a token that doesn't exist.""" channel = self.make_request( "PUT", @@ -424,7 +427,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_update_uses_allowed(self): + def test_update_uses_allowed(self) -> None: """Test updating just uses_allowed.""" # Create new token using default values token = self._new_token() @@ -490,7 +493,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_expiry_time(self): + def test_update_expiry_time(self) -> None: """Test updating just expiry_time.""" # Create new token using default values token = self._new_token() @@ -547,7 +550,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_both(self): + def test_update_both(self) -> None: """Test updating both uses_allowed and expiry_time.""" # Create new token using default values token = self._new_token() @@ -569,7 +572,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) - def test_update_invalid_type(self): + def test_update_invalid_type(self) -> None: """Test using invalid types doesn't work.""" # Create new token using default values token = self._new_token() @@ -595,7 +598,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # DELETING - def test_delete_no_auth(self): + def test_delete_no_auth(self) -> None: """Try to delete a token without authentication.""" channel = self.make_request( "DELETE", @@ -609,7 +612,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_delete_requester_not_admin(self): + def test_delete_requester_not_admin(self) -> None: """Try to delete a token while not an admin.""" channel = self.make_request( "DELETE", @@ -624,7 +627,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_delete_non_existent(self): + def test_delete_non_existent(self) -> None: """Try to delete a token that doesn't exist.""" channel = self.make_request( "DELETE", @@ -640,7 +643,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_delete(self): + def test_delete(self) -> None: """Test deleting a token.""" # Create new token using default values token = self._new_token() @@ -656,7 +659,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # GETTING ONE - def test_get_no_auth(self): + def test_get_no_auth(self) -> None: """Try to get a token without authentication.""" channel = self.make_request( "GET", @@ -670,7 +673,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_get_requester_not_admin(self): + def test_get_requester_not_admin(self) -> None: """Try to get a token while not an admin.""" channel = self.make_request( "GET", @@ -685,7 +688,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_get_non_existent(self): + def test_get_non_existent(self) -> None: """Try to get a token that doesn't exist.""" channel = self.make_request( "GET", @@ -701,7 +704,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_get(self): + def test_get(self) -> None: """Test getting a token.""" # Create new token using default values token = self._new_token() @@ -722,7 +725,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # LISTING - def test_list_no_auth(self): + def test_list_no_auth(self) -> None: """Try to list tokens without authentication.""" channel = self.make_request("GET", self.url, {}) self.assertEqual( @@ -732,7 +735,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_list_requester_not_admin(self): + def test_list_requester_not_admin(self) -> None: """Try to list tokens while not an admin.""" channel = self.make_request( "GET", @@ -747,7 +750,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_list_all(self): + def test_list_all(self) -> None: """Test listing all tokens.""" # Create new token using default values token = self._new_token() @@ -768,7 +771,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(token_info["pending"], 0) self.assertEqual(token_info["completed"], 0) - def test_list_invalid_query_parameter(self): + def test_list_invalid_query_parameter(self) -> None: """Test with `valid` query parameter not `true` or `false`.""" channel = self.make_request( "GET", @@ -783,7 +786,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): msg=channel.json_body, ) - def _test_list_query_parameter(self, valid: str): + def _test_list_query_parameter(self, valid: str) -> None: """Helper used to test both valid=true and valid=false.""" # Create 2 valid and 2 invalid tokens. now = self.hs.get_clock().time_msec() @@ -820,10 +823,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertIn(token_info_1["token"], tokens) self.assertIn(token_info_2["token"], tokens) - def test_list_valid(self): + def test_list_valid(self) -> None: """Test listing just valid tokens.""" self._test_list_query_parameter(valid="true") - def test_list_invalid(self): + def test_list_invalid(self) -> None: """Test listing just invalid tokens.""" self._test_list_query_parameter(valid="false") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 681f9173ef..d3858e460d 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import urllib.parse from http import HTTPStatus from typing import List, Optional @@ -19,11 +18,15 @@ from unittest.mock import Mock from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -39,7 +42,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): room.register_deprecated_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.event_creation_handler = hs.get_event_creation_handler() hs.config.consent.user_consent_version = "1" @@ -455,7 +458,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): room.register_deprecated_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.event_creation_handler = hs.get_event_creation_handler() hs.config.consent.user_consent_version = "1" @@ -1062,12 +1065,12 @@ class RoomTestCase(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Create user self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - def test_list_rooms(self): + def test_list_rooms(self) -> None: """Test that we can list rooms""" # Create 3 test rooms total_rooms = 3 @@ -1131,7 +1134,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # We shouldn't receive a next token here as there's no further rooms to show self.assertNotIn("next_batch", channel.json_body) - def test_list_rooms_pagination(self): + def test_list_rooms_pagination(self) -> None: """Test that we can get a full list of rooms through pagination""" # Create 5 test rooms total_rooms = 5 @@ -1213,7 +1216,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def test_correct_room_attributes(self): + def test_correct_room_attributes(self) -> None: """Test the correct attributes for a room are returned""" # Create a test room room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1294,7 +1297,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(test_room_name, r["name"]) self.assertEqual(test_alias, r["canonical_alias"]) - def test_room_list_sort_order(self): + def test_room_list_sort_order(self) -> None: """Test room list sort ordering. alphabetical name versus number of members, reversing the order, etc. """ @@ -1303,7 +1306,7 @@ class RoomTestCase(unittest.HomeserverTestCase): order_type: str, expected_room_list: List[str], reverse: bool = False, - ): + ) -> None: """Request the list of rooms in a certain order. Assert that order is what we expect @@ -1432,7 +1435,7 @@ class RoomTestCase(unittest.HomeserverTestCase): _order_test("state_events", [room_id_3, room_id_2, room_id_1]) _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) - def test_search_term(self): + def test_search_term(self) -> None: """Test that searching for a room works correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1461,7 +1464,7 @@ class RoomTestCase(unittest.HomeserverTestCase): expected_room_id: Optional[str], search_term: str, expected_http_code: int = HTTPStatus.OK, - ): + ) -> None: """Search for a room and check that the returned room's id is a match Args: @@ -1535,7 +1538,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Test search local part of alias _search_test(room_id_1, "alias1") - def test_search_term_non_ascii(self): + def test_search_term_non_ascii(self) -> None: """Test that searching for a room with non-ASCII characters works correctly""" # Create test room @@ -1562,7 +1565,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name")) - def test_single_room(self): + def test_single_room(self) -> None: """Test that a single room can be requested correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1613,7 +1616,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id_1, channel.json_body["room_id"]) - def test_single_room_devices(self): + def test_single_room_devices(self) -> None: """Test that `joined_local_devices` can be requested correctly""" room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1652,7 +1655,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["joined_local_devices"]) - def test_room_members(self): + def test_room_members(self) -> None: """Test that room members can be requested correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1700,7 +1703,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["total"], 3) - def test_room_state(self): + def test_room_state(self) -> None: """Test that room state can be requested correctly""" # Create two test rooms room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1717,7 +1720,9 @@ class RoomTestCase(unittest.HomeserverTestCase): # the create_room already does the right thing, so no need to verify that we got # the state events it created. - def _set_canonical_alias(self, room_id: str, test_alias: str, admin_user_tok: str): + def _set_canonical_alias( + self, room_id: str, test_alias: str, admin_user_tok: str + ) -> None: # Create a new alias to this room url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) channel = self.make_request( @@ -1752,7 +1757,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1767,7 +1772,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): ) self.url = f"/_synapse/admin/v1/join/{self.public_room_id}" - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ @@ -1782,7 +1787,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If a parameter is missing, return an error """ @@ -1797,7 +1802,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - def test_local_user_does_not_exist(self): + def test_local_user_does_not_exist(self) -> None: """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -1812,7 +1817,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_remote_user(self): + def test_remote_user(self) -> None: """ Check that only local user can join rooms. """ @@ -1830,7 +1835,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_room_does_not_exist(self): + def test_room_does_not_exist(self) -> None: """ Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. """ @@ -1846,7 +1851,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual("No known servers", channel.json_body["error"]) - def test_room_is_not_valid(self): + def test_room_is_not_valid(self) -> None: """ Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. """ @@ -1865,7 +1870,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_join_public_room(self): + def test_join_public_room(self) -> None: """ Test joining a local user to a public room with "JoinRules.PUBLIC" """ @@ -1890,7 +1895,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) - def test_join_private_room_if_not_member(self): + def test_join_private_room_if_not_member(self) -> None: """ Test joining a local user to a private room with "JoinRules.INVITE" when server admin is not member of this room. @@ -1910,7 +1915,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_join_private_room_if_member(self): + def test_join_private_room_if_member(self) -> None: """ Test joining a local user to a private room with "JoinRules.INVITE", when server admin is member of this room. @@ -1961,7 +1966,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - def test_join_private_room_if_owner(self): + def test_join_private_room_if_owner(self) -> None: """ Test joining a local user to a private room with "JoinRules.INVITE", when server admin is owner of this room. @@ -1991,7 +1996,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - def test_context_as_non_admin(self): + def test_context_as_non_admin(self) -> None: """ Test that, without being admin, one cannot use the context admin API """ @@ -2025,7 +2030,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_context_as_admin(self): + def test_context_as_admin(self) -> None: """ Test that, as admin, we can find the context of an event without having joined the room. """ @@ -2081,7 +2086,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2098,7 +2103,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): self.public_room_id ) - def test_public_room(self): + def test_public_room(self) -> None: """Test that getting admin in a public room works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -2123,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.admin_user_tok, ) - def test_private_room(self): + def test_private_room(self) -> None: """Test that getting admin in a private room works and we get invited.""" room_id = self.helper.create_room_as( self.creator, @@ -2151,7 +2156,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.admin_user_tok, ) - def test_other_user(self): + def test_other_user(self) -> None: """Test that giving admin in a public room works to a non-admin user works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -2176,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.second_tok, ) - def test_not_enough_power(self): + def test_not_enough_power(self) -> None: """Test that we get a sensible error if there are no local room admins.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -2216,7 +2221,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self._store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -2231,7 +2236,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/rooms/%s/block" @parameterized.expand([("PUT",), ("GET",)]) - def test_requester_is_no_admin(self, method: str): + def test_requester_is_no_admin(self, method: str) -> None: """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.""" channel = self.make_request( @@ -2245,7 +2250,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand([("PUT",), ("GET",)]) - def test_room_is_not_valid(self, method: str): + def test_room_is_not_valid(self, method: str) -> None: """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.""" channel = self.make_request( @@ -2261,7 +2266,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_block_is_not_valid(self): + def test_block_is_not_valid(self) -> None: """If parameter `block` is not valid, return an error.""" # `block` is not valid @@ -2296,7 +2301,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) - def test_block_room(self): + def test_block_room(self) -> None: """Test that block a room is successful.""" def _request_and_test_block_room(room_id: str) -> None: @@ -2320,7 +2325,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): # unknown remote room _request_and_test_block_room("!unknown:remote") - def test_block_room_twice(self): + def test_block_room_twice(self) -> None: """Test that block a room that is already blocked is successful.""" self._is_blocked(self.room_id, expect=False) @@ -2335,7 +2340,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertTrue(channel.json_body["block"]) self._is_blocked(self.room_id, expect=True) - def test_unblock_room(self): + def test_unblock_room(self) -> None: """Test that unblock a room is successful.""" def _request_and_test_unblock_room(room_id: str) -> None: @@ -2360,7 +2365,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): # unknown remote room _request_and_test_unblock_room("!unknown:remote") - def test_unblock_room_twice(self): + def test_unblock_room_twice(self) -> None: """Test that unblock a room that is not blocked is successful.""" self._block_room(self.room_id) @@ -2375,7 +2380,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body["block"]) self._is_blocked(self.room_id, expect=False) - def test_get_blocked_room(self): + def test_get_blocked_room(self) -> None: """Test get status of a blocked room""" def _request_blocked_room(room_id: str) -> None: @@ -2399,7 +2404,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): # unknown remote room _request_blocked_room("!unknown:remote") - def test_get_unblocked_room(self): + def test_get_unblocked_room(self) -> None: """Test get status of a unblocked room""" def _request_unblocked_room(room_id: str) -> None: diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 0b9da4c732..3c59f5f766 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from http import HTTPStatus from typing import List +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, room, sync +from synapse.server import HomeServer from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -34,7 +37,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() @@ -49,7 +52,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/send_server_notice" - def test_no_auth(self): + def test_no_auth(self) -> None: """Try to send a server notice without authentication.""" channel = self.make_request("POST", self.url) @@ -60,7 +63,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """If the user is not a server admin, an error is returned.""" channel = self.make_request( "POST", @@ -76,7 +79,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" channel = self.make_request( "POST", @@ -89,7 +92,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -109,7 +112,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """If parameters are invalid, an error is returned.""" # no content, no user @@ -157,7 +160,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'msgtype' not in content", channel.json_body["error"]) - def test_server_notice_disabled(self): + def test_server_notice_disabled(self) -> None: """Tests that server returns error if server notice is disabled""" channel = self.make_request( "POST", @@ -176,7 +179,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice(self): + def test_send_server_notice(self) -> None: """ Tests that sending two server notices is successfully, the server uses the same room and do not send messages twice. @@ -240,7 +243,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(messages[1]["sender"], "@notices:test") @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice_leave_room(self): + def test_send_server_notice_leave_room(self) -> None: """ Tests that sending a server notices is successfully. The user leaves the room and the second message appears @@ -324,7 +327,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertNotEqual(first_room_id, second_room_id) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice_delete_room(self): + def test_send_server_notice_delete_room(self) -> None: """ Tests that the user get server notice in a new room after the first server notice room was deleted. @@ -414,7 +417,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): def _check_invite_and_join_status( self, user_id: str, expected_invites: int, expected_memberships: int - ) -> RoomsForUser: + ) -> List[RoomsForUser]: """Check invite and room membership status of a user. Args diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 43d8ca032b..7cb8ec57ba 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -12,13 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from http import HTTPStatus -from typing import Any, Dict, List, Optional +from typing import List, Optional + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import SMALL_PNG @@ -30,7 +34,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -41,7 +45,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/statistics/users/media" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to list users without authentication. """ @@ -54,7 +58,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ @@ -72,7 +76,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ @@ -188,7 +192,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_limit(self): + def test_limit(self) -> None: """ Testing list of media with limit """ @@ -206,7 +210,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["users"]) - def test_from(self): + def test_from(self) -> None: """ Testing list of media with a defined starting point (from) """ @@ -224,7 +228,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["users"]) - def test_limit_and_from(self): + def test_limit_and_from(self) -> None: """ Testing list of media with a defined starting point and limit """ @@ -242,7 +246,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["users"]), 10) self._check_fields(channel.json_body["users"]) - def test_next_token(self): + def test_next_token(self) -> None: """ Testing that `next_token` appears at the right place """ @@ -302,7 +306,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) - def test_no_media(self): + def test_no_media(self) -> None: """ Tests that a normal lookup for statistics is successfully if users have no media created @@ -318,7 +322,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["users"])) - def test_order_by(self): + def test_order_by(self) -> None: """ Testing order list with parameter `order_by` """ @@ -396,7 +400,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): "b", ) - def test_from_until_ts(self): + def test_from_until_ts(self) -> None: """ Testing filter by time with parameters `from_ts` and `until_ts` """ @@ -448,7 +452,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) - def test_search_term(self): + def test_search_term(self) -> None: self._create_users_with_media(20, 1) # check without filter get all users @@ -488,7 +492,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) - def _create_users_with_media(self, number_users: int, media_per_user: int): + def _create_users_with_media(self, number_users: int, media_per_user: int) -> None: """ Create a number of users with a number of media Args: @@ -500,7 +504,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): user_tok = self.login("foo_user_%s" % i, "pass") self._create_media(user_tok, media_per_user) - def _create_media(self, user_token: str, number_media: int): + def _create_media(self, user_token: str, number_media: int) -> None: """ Create a number of media for a specific user Args: @@ -514,7 +518,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK ) - def _check_fields(self, content: List[Dict[str, Any]]): + def _check_fields(self, content: List[JsonDict]) -> None: """Checks that all attributes are present in content Args: content: List that is checked for content @@ -527,7 +531,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): def _order_test( self, order_type: str, expected_user_list: List[str], dir: Optional[str] = None - ): + ) -> None: """Request the list of users in a certain order. Assert that order is what we expect Args: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 03aa689ace..4fedd5fd08 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -127,14 +127,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob", "password": "abc123", "admin": True, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -153,7 +153,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update( nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, @@ -161,7 +161,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "password": "abc123", "admin": True, "user_type": UserTypes.SUPPORT, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -177,14 +177,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob", "password": "abc123", "admin": True, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -308,13 +308,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob1", "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -332,14 +332,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob2", "displayname": None, "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -356,14 +356,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob3", "displayname": "", "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -379,14 +379,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob4", "displayname": "Bob's Name", "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -426,7 +426,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update( nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, @@ -434,7 +434,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "password": "abc123", "admin": True, "user_type": UserTypes.SUPPORT, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -870,7 +870,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(expected_user_list, returned_order) self._check_fields(channel.json_body["users"]) - def _check_fields(self, content: JsonDict): + def _check_fields(self, content: List[JsonDict]): """Checks that the expected user attributes are present in content Args: content: List that is checked for content @@ -3235,7 +3235,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): return media_id - def _check_fields(self, content: JsonDict): + def _check_fields(self, content: List[JsonDict]): """Checks that the expected user attributes are present in content Args: content: List that is checked for content -- cgit 1.5.1 From 7ecaa3b976b04dc5b2c6786aa18845016b80dd01 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 8 Dec 2021 17:59:40 +0100 Subject: Clean up `synapse.rest.admin` (#11535) --- changelog.d/11535.misc | 1 + synapse/rest/admin/__init__.py | 4 +- synapse/rest/admin/background_updates.py | 16 +++---- synapse/rest/admin/devices.py | 20 ++++----- synapse/rest/admin/event_reports.py | 2 - synapse/rest/admin/federation.py | 2 +- synapse/rest/admin/groups.py | 2 +- synapse/rest/admin/media.py | 60 ++++++++----------------- synapse/rest/admin/registration_tokens.py | 3 -- synapse/rest/admin/rooms.py | 70 +++++++++-------------------- synapse/rest/admin/server_notice_servlet.py | 4 +- synapse/rest/admin/statistics.py | 22 ++++----- synapse/rest/admin/username_available.py | 2 +- synapse/rest/admin/users.py | 51 ++++++++++----------- tests/rest/admin/test_statistics.py | 2 +- 15 files changed, 96 insertions(+), 165 deletions(-) create mode 100644 changelog.d/11535.misc (limited to 'tests/rest/admin/test_statistics.py') diff --git a/changelog.d/11535.misc b/changelog.d/11535.misc new file mode 100644 index 0000000000..580ac354ab --- /dev/null +++ b/changelog.d/11535.misc @@ -0,0 +1 @@ +Clean up `synapse.rest.admin`. \ No newline at end of file diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c499afd4be..701c609c12 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -108,7 +108,7 @@ class VersionServlet(RestServlet): class PurgeHistoryRestServlet(RestServlet): PATTERNS = admin_patterns( - "/purge_history/(?P[^/]*)(/(?P[^/]+))?" + "/purge_history/(?P[^/]*)(/(?P[^/]*))?$" ) def __init__(self, hs: "HomeServer"): @@ -195,7 +195,7 @@ class PurgeHistoryRestServlet(RestServlet): class PurgeHistoryStatusRestServlet(RestServlet): - PATTERNS = admin_patterns("/purge_history_status/(?P[^/]+)") + PATTERNS = admin_patterns("/purge_history_status/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.pagination_handler = hs.get_pagination_handler() diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index 479672d4d5..6ec00ce0b9 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -22,7 +22,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) from synapse.http.site import SynapseRequest -from synapse.rest.admin._base import admin_patterns, assert_user_is_admin +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.types import JsonDict if TYPE_CHECKING: @@ -41,8 +41,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet): self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) @@ -51,8 +50,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet): return HTTPStatus.OK, {"enabled": enabled} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) body = parse_json_object_from_request(request) @@ -84,8 +82,7 @@ class BackgroundUpdateRestServlet(RestServlet): self._data_stores = hs.get_datastores() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) # We need to check that all configured databases have updates enabled. # (They *should* all be in sync.) @@ -111,15 +108,14 @@ class BackgroundUpdateRestServlet(RestServlet): class BackgroundUpdateStartJobRestServlet(RestServlet): """Allows to start specific background updates""" - PATTERNS = admin_patterns("/background_updates/start_job") + PATTERNS = admin_patterns("/background_updates/start_job$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() self._store = hs.get_datastore() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_requester_is_admin(self._auth, request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["job_name"]) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index 2e5a6600d3..062a33d28d 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str, device_id: str @@ -53,7 +53,7 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -71,7 +71,7 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -87,7 +87,7 @@ class DeviceRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -109,14 +109,10 @@ class DevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/devices$", "v2") def __init__(self, hs: "HomeServer"): - """ - Args: - hs: server - """ - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -124,7 +120,7 @@ class DevicesRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) @@ -144,10 +140,10 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/delete_devices$", "v2") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, user_id: str @@ -155,7 +151,7 @@ class DeleteDevicesRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users") u = await self.store.get_user_by_id(target_user.to_string()) diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 5ee8b11110..38477f8ead 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 744687be35..50d88c9109 100644 --- a/synapse/rest/admin/federation.py +++ b/synapse/rest/admin/federation.py @@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet): 200 OK with details of a destination if success otherwise an error. """ - PATTERNS = admin_patterns("/federation/destinations/(?P[^/]+)$") + PATTERNS = admin_patterns("/federation/destinations/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py index a27110388f..cd697e180e 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class DeleteGroupAdminRestServlet(RestServlet): """Allows deleting of local groups""" - PATTERNS = admin_patterns("/delete_group/(?P[^/]*)") + PATTERNS = admin_patterns("/delete_group/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.group_server = hs.get_groups_server_handler() diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 9e23e2d8fc..7236e4027f 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -17,7 +17,7 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest @@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet): """ PATTERNS = [ - *admin_patterns("/room/(?P[^/]+)/media/quarantine$"), + *admin_patterns("/room/(?P[^/]*)/media/quarantine$"), # This path kept around for legacy reasons - *admin_patterns("/quarantine_media/(?P[^/]+)"), + *admin_patterns("/quarantine_media/(?P[^/]*)$"), ] def __init__(self, hs: "HomeServer"): @@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet): this server. """ - PATTERNS = admin_patterns("/user/(?P[^/]+)/media/quarantine$") + PATTERNS = admin_patterns("/user/(?P[^/]*)/media/quarantine$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/quarantine/(?P[^/]+)/(?P[^/]+)" + "/media/quarantine/(?P[^/]*)/(?P[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet): """ PATTERNS = admin_patterns( - "/media/unquarantine/(?P[^/]+)/(?P[^/]+)" + "/media/unquarantine/(?P[^/]*)/(?P[^/]*)$" ) def __init__(self, hs: "HomeServer"): @@ -138,8 +138,7 @@ class UnquarantineMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info( "Remove from quarantine local media by ID: %s/%s", server_name, media_id @@ -154,7 +153,7 @@ class UnquarantineMediaByID(RestServlet): class ProtectMediaByID(RestServlet): """Protect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/protect/(?P[^/]+)") + PATTERNS = admin_patterns("/media/protect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -163,8 +162,7 @@ class ProtectMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Protecting local media by ID: %s", media_id) @@ -177,7 +175,7 @@ class ProtectMediaByID(RestServlet): class UnprotectMediaByID(RestServlet): """Unprotect local media from being quarantined.""" - PATTERNS = admin_patterns("/media/unprotect/(?P[^/]+)") + PATTERNS = admin_patterns("/media/unprotect/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -186,8 +184,7 @@ class UnprotectMediaByID(RestServlet): async def on_POST( self, request: SynapseRequest, media_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) logging.info("Unprotecting local media by ID: %s", media_id) @@ -200,7 +197,7 @@ class UnprotectMediaByID(RestServlet): class ListMediaInRoom(RestServlet): """Lists all of the media in a given room.""" - PATTERNS = admin_patterns("/room/(?P[^/]+)/media$") + PATTERNS = admin_patterns("/room/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -209,10 +206,7 @@ class ListMediaInRoom(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - is_admin = await self.auth.is_server_admin(requester.user) - if not is_admin: - raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") + await assert_requester_is_admin(self.auth, request) local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id) @@ -254,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet): class DeleteMediaByID(RestServlet): """Delete local media by a given ID. Removes it from this server.""" - PATTERNS = admin_patterns("/media/(?P[^/]+)/(?P[^/]+)") + PATTERNS = admin_patterns("/media/(?P[^/]*)/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet): timestamp and size. """ - PATTERNS = admin_patterns("/media/(?P[^/]+)/delete$") + PATTERNS = admin_patterns("/media/(?P[^/]*)/delete$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet): media that exist given for this user """ - PATTERNS = admin_patterns("/users/(?P[^/]+)/media$") + PATTERNS = admin_patterns("/users/(?P[^/]*)/media$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -403,16 +397,7 @@ class UserMediaRestServlet(RestServlet): request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") @@ -470,16 +455,7 @@ class UserMediaRestServlet(RestServlet): request, "order_by", default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), + allowed_values=[sort_order.value for sort_order in MediaSortOrder], ) direction = parse_string( request, "dir", default="f", allowed_values=("f", "b") diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 891b98c088..04948b6408 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/new$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/registration_tokens/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.clock = hs.get_clock() self.auth = hs.get_auth() self.store = hs.get_datastore() diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 829e86675a..17c6df1cc8 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet): If 'purge' is true, it will remove all traces of a room from the database. """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)$", "v2") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -123,7 +123,7 @@ class RoomRestV2Servlet(RestServlet): class DeleteRoomStatusByRoomIdRestServlet(RestServlet): """Get the status of the delete room background task.""" - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/delete_status$", "v2") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/delete_status$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -160,7 +160,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet): class DeleteRoomStatusByDeleteIdRestServlet(RestServlet): """Get the status of the delete room background task.""" - PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]+)$", "v2") + PATTERNS = admin_patterns("/rooms/delete_status/(?P[^/]*)$", "v2") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() @@ -193,35 +193,17 @@ class ListRoomRestServlet(RestServlet): self.admin_handler = hs.get_admin_handler() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) # Extract query parameters start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) - order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value) - if order_by not in ( - RoomSortOrder.ALPHABETICAL.value, - RoomSortOrder.SIZE.value, - RoomSortOrder.NAME.value, - RoomSortOrder.CANONICAL_ALIAS.value, - RoomSortOrder.JOINED_MEMBERS.value, - RoomSortOrder.JOINED_LOCAL_MEMBERS.value, - RoomSortOrder.VERSION.value, - RoomSortOrder.CREATOR.value, - RoomSortOrder.ENCRYPTION.value, - RoomSortOrder.FEDERATABLE.value, - RoomSortOrder.PUBLIC.value, - RoomSortOrder.JOIN_RULES.value, - RoomSortOrder.GUEST_ACCESS.value, - RoomSortOrder.HISTORY_VISIBILITY.value, - RoomSortOrder.STATE_EVENTS.value, - ): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) + order_by = parse_string( + request, + "order_by", + default=RoomSortOrder.NAME.value, + allowed_values=[sort_order.value for sort_order in RoomSortOrder], + ) search_term = parse_string(request, "search_term", encoding="utf-8") if search_term == "": @@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet): TODO: Add on_POST to allow room creation without joining the room """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)$") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() @@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet): Get members list of a room. """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/members") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/members$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet): Get full state within a room. """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/state") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/state$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -436,8 +415,7 @@ class RoomStateRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) ret = await self.store.get_room(room_id) if not ret: @@ -454,14 +432,14 @@ class RoomStateRestServlet(RestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = admin_patterns("/join/(?P[^/]*)") + PATTERNS = admin_patterns("/join/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() + self.is_mine = hs.is_mine async def on_POST( self, request: SynapseRequest, room_identifier: str @@ -477,7 +455,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): assert_params_in_dict(content, ["user_id"]) target_user = UserID.from_string(content["user_id"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "This endpoint can only be used with local users", @@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): } """ - PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/make_room_admin$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_creation_handler = hs.get_event_creation_handler() @@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): GET /_synapse/admin/v1/rooms//forward_extremities """ - PATTERNS = admin_patterns("/rooms/(?P[^/]*)/forward_extremities") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/forward_extremities$") def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() async def on_DELETE( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) @@ -710,8 +685,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet): async def on_GET( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_requester_is_admin(self.auth, request) room_id, _ = await self.resolve_room_id(room_identifier) @@ -793,7 +767,7 @@ class BlockRoomRestServlet(RestServlet): On GET: Get blocking status of room and user who has blocked this room. """ - PATTERNS = admin_patterns("/rooms/(?P[^/]+)/block$") + PATTERNS = admin_patterns("/rooms/(?P[^/]*)/block$") def __init__(self, hs: "HomeServer"): self._auth = hs.get_auth() diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index b295fb078b..15da9cd881 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.server_notices_manager = hs.get_server_notices_manager() self.admin_handler = hs.get_admin_handler() self.txns = HttpTransactionCache(hs) + self.is_mine = hs.is_mine def register(self, json_resource: HttpServer) -> None: PATTERN = "/send_server_notice" @@ -88,7 +88,7 @@ class SendServerNoticeServlet(RestServlet): ) target_user = UserID.from_string(body["user_id"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users" ) diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py index ca41fd45f2..7a6546372e 100644 --- a/synapse/rest/admin/statistics.py +++ b/synapse/rest/admin/statistics.py @@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet): PATTERNS = admin_patterns("/statistics/users/media$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -45,19 +44,16 @@ class UserMediaStatisticsRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) order_by = parse_string( - request, "order_by", default=UserSortOrder.USER_ID.value + request, + "order_by", + default=UserSortOrder.USER_ID.value, + allowed_values=( + UserSortOrder.MEDIA_LENGTH.value, + UserSortOrder.MEDIA_COUNT.value, + UserSortOrder.USER_ID.value, + UserSortOrder.DISPLAYNAME.value, + ), ) - if order_by not in ( - UserSortOrder.MEDIA_LENGTH.value, - UserSortOrder.MEDIA_COUNT.value, - UserSortOrder.USER_ID.value, - UserSortOrder.DISPLAYNAME.value, - ): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Unknown value for order_by: %s" % (order_by,), - errcode=Codes.INVALID_PARAM, - ) start = parse_integer(request, "from", default=0) if start < 0: diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py index 2bf1472967..5353dc3682 100644 --- a/synapse/rest/admin/username_available.py +++ b/synapse/rest/admin/username_available.py @@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/username_available") + PATTERNS = admin_patterns("/username_available$") def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2a60b602b1..db678da4cf 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet): """ def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() @@ -126,7 +125,7 @@ class UsersRestServletV2(RestServlet): class UserRestServletV2(RestServlet): - PATTERNS = admin_patterns("/users/(?P[^/]+)$", "v2") + PATTERNS = admin_patterns("/users/(?P[^/]*)$", "v2") """Get request to list user details. This needs user to have administrator access in Synapse. @@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet): nonce to the time it was generated, in int seconds. """ - PATTERNS = admin_patterns("/register") + PATTERNS = admin_patterns("/register$") NONCE_TIMEOUT = 60 def __init__(self, hs: "HomeServer"): @@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet): ] def __init__(self, hs: "HomeServer"): - self.hs = hs self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -575,7 +574,7 @@ class WhoisRestServlet(RestServlet): if target_user != auth_user: await assert_user_is_admin(self.auth, auth_user) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") ret = await self.admin_handler.get_whois(target_user) @@ -584,7 +583,7 @@ class WhoisRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): - PATTERNS = admin_patterns("/deactivate/(?P[^/]*)") + PATTERNS = admin_patterns("/deactivate/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self._deactivate_account_handler = hs.get_deactivate_account_handler() @@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet): PATTERNS = admin_patterns("/account_validity/validity$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() @@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet): 200 OK with empty object if success otherwise an error. """ - PATTERNS = admin_patterns("/reset_password/(?P[^/]*)") + PATTERNS = admin_patterns("/reset_password/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler() @@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet): 200 OK with json object {list[dict[str, Any]], count} or empty object. """ - PATTERNS = admin_patterns("/search_users/(?P[^/]*)") + PATTERNS = admin_patterns("/search_users/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, target_user_id: str @@ -740,7 +737,7 @@ class SearchUsersRestServlet(RestServlet): # if not is_admin and target_user != auth_user: # raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user") term = parse_string(request, "term", required=True) @@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/admin$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine = hs.is_mine async def on_GET( self, request: SynapseRequest, user_id: str @@ -790,7 +787,7 @@ class UserAdminServlet(RestServlet): target_user = UserID.from_string(user_id) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be admins of this homeserver", @@ -813,7 +810,7 @@ class UserAdminServlet(RestServlet): assert_params_in_dict(body, ["admin"]) - if not self.hs.is_mine(target_user): + if not self.is_mine(target_user): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be admins of this homeserver", @@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet): Get room list of an user. """ - PATTERNS = admin_patterns("/users/(?P[^/]+)/joined_rooms$") + PATTERNS = admin_patterns("/users/(?P[^/]*)/joined_rooms$") def __init__(self, hs: "HomeServer"): self.is_mine = hs.is_mine @@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet): PATTERNS = admin_patterns("/users/(?P[^/]*)/login$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str @@ -921,7 +918,7 @@ class UserTokenRestServlet(RestServlet): await assert_user_is_admin(self.auth, requester.user) auth_user = requester.user - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be logged in as" ) @@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet): {} """ - PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban") + PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_POST( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" ) @@ -1001,7 +998,7 @@ class ShadowBanRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned" ) @@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet): } """ - PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit") + PATTERNS = admin_patterns("/users/(?P[^/]*)/override_ratelimit$") def __init__(self, hs: "HomeServer"): - self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") if not await self.store.get_user_by_id(user_id): @@ -1068,7 +1065,7 @@ class RateLimitRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" ) @@ -1113,7 +1110,7 @@ class RateLimitRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) - if not self.hs.is_mine_id(user_id): + if not self.is_mine_id(user_id): raise SynapseError( HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited" ) diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 7cb8ec57ba..f6e85fdaad 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -92,7 +92,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): channel.code, msg=channel.json_body, ) - self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # negative from channel = self.make_request( -- cgit 1.5.1