diff options
-rw-r--r-- | changelog.d/11048.misc | 1 | ||||
-rw-r--r-- | tests/rest/admin/test_user.py | 401 |
2 files changed, 147 insertions, 255 deletions
diff --git a/changelog.d/11048.misc b/changelog.d/11048.misc new file mode 100644 index 0000000000..22d3c956f5 --- /dev/null +++ b/changelog.d/11048.misc @@ -0,0 +1 @@ +Simplify the user admin API tests. \ No newline at end of file diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 6ed9e42173..c9e2754b09 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -14,14 +14,13 @@ import hashlib import hmac -import json import os import urllib.parse from binascii import unhexlify from typing import List, Optional from unittest.mock import Mock, patch -from parameterized import parameterized +from parameterized import parameterized, parameterized_class import synapse.rest.admin from synapse.api.constants import UserTypes @@ -104,8 +103,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # 59 seconds self.reactor.advance(59) - body = json.dumps({"nonce": nonce}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) @@ -113,7 +112,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # 61 seconds self.reactor.advance(2) - channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -129,18 +128,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("HMAC incorrect", channel.json_body["error"]) def test_register_correct_nonce(self): @@ -157,17 +154,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "user_type": UserTypes.SUPPORT, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "user_type": UserTypes.SUPPORT, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -183,22 +178,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it - channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -218,9 +211,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Nonce check # - # Must be present - body = json.dumps({}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + # Must be an empty body present + channel = self.make_request("POST", self.url, {}) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) @@ -230,29 +222,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Must be present - body = json.dumps({"nonce": nonce()}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, {"nonce": nonce()}) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string - body = json.dumps({"nonce": nonce(), "username": 1234}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": 1234} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "abcd\u0000"} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a" * 1000} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) @@ -262,29 +253,29 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Must be present - body = json.dumps({"nonce": nonce(), "username": "a"}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a"} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string - body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a", "password": 1234} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long - body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a", "password": "A" * 1000} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) @@ -294,15 +285,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Invalid user_type - body = json.dumps( - { - "nonce": nonce(), - "username": "a", - "password": "1234", - "user_type": "invalid", - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce(), + "username": "a", + "password": "1234", + "user_type": "invalid", + } + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) @@ -320,10 +309,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac} - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob1", + "password": "abc123", + "mac": want_mac, + } + + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) @@ -340,16 +333,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob2", - "displayname": None, - "password": "abc123", - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob2", + "displayname": None, + "password": "abc123", + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) @@ -366,22 +357,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob3", - "displayname": "", - "password": "abc123", - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob3", + "displayname": "", + "password": "abc123", + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) # set displayname channel = self.make_request("GET", self.url) @@ -391,16 +380,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob4", - "displayname": "Bob's Name", - "password": "abc123", - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob4", + "displayname": "Bob's Name", + "password": "abc123", + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) @@ -440,17 +427,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "user_type": UserTypes.SUPPORT, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "user_type": UserTypes.SUPPORT, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -993,12 +978,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): """ If parameter `erase` is not boolean, return an error """ - body = json.dumps({"erase": "False"}) channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content={"erase": "False"}, access_token=self.admin_user_tok, ) @@ -2201,7 +2185,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2216,7 +2200,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): @@ -2359,7 +2343,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2374,7 +2358,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): @@ -3073,7 +3057,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_not_admin(self): @@ -3082,7 +3066,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): "POST", self.url, b"{}", access_token=self.other_user_tok ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) def test_send_event(self): """Test that sending event as a user works.""" @@ -3127,7 +3111,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( @@ -3160,7 +3144,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) def test_admin_logout_all(self): """Tests that the admin user calling `/logout/all` does expire the @@ -3181,7 +3165,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( @@ -3242,6 +3226,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, user=self.other_user, tok=puppet_token) +@parameterized_class( + ("url_prefix",), + [ + ("/_synapse/admin/v1/whois/%s",), + ("/_matrix/client/r0/admin/whois/%s",), + ], +) class WhoisRestTestCase(unittest.HomeserverTestCase): servlets = [ @@ -3254,21 +3245,14 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") - self.url1 = "/_synapse/admin/v1/whois/%s" % urllib.parse.quote(self.other_user) - self.url2 = "/_matrix/client/r0/admin/whois/%s" % urllib.parse.quote( - self.other_user - ) + self.url = self.url_prefix % self.other_user def test_no_auth(self): """ Try to get information of an user without authentication. """ - channel = self.make_request("GET", self.url1, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("GET", self.url2, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + channel = self.make_request("GET", self.url, b"{}") + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -3280,38 +3264,21 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - self.url1, - access_token=other_user2_token, - ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "GET", - self.url2, + self.url, access_token=other_user2_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_is_not_local(self): """ Tests that a lookup for a user that is not a local returns a 400 """ - url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" - url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain" - - channel = self.make_request( - "GET", - url1, - access_token=self.admin_user_tok, - ) - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only whois a local user", channel.json_body["error"]) + url = self.url_prefix % "@unknown_person:unknown_domain" channel = self.make_request( "GET", - url2, + url, access_token=self.admin_user_tok, ) self.assertEqual(400, channel.code, msg=channel.json_body) @@ -3323,16 +3290,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( "GET", - self.url1, - access_token=self.admin_user_tok, - ) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual(self.other_user, channel.json_body["user_id"]) - self.assertIn("devices", channel.json_body) - - channel = self.make_request( - "GET", - self.url2, + self.url, access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -3347,16 +3305,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - self.url1, - access_token=other_user_token, - ) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual(self.other_user, channel.json_body["user_id"]) - self.assertIn("devices", channel.json_body) - - channel = self.make_request( - "GET", - self.url2, + self.url, access_token=other_user_token, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -3388,7 +3337,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. """ channel = self.make_request("POST", self.url) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -3398,7 +3347,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): other_user_token = self.login("user", "pass") channel = self.make_request("POST", self.url, access_token=other_user_token) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_is_not_local(self): @@ -3447,84 +3396,41 @@ class RateLimitTestCase(unittest.HomeserverTestCase): % urllib.parse.quote(self.other_user) ) - def test_no_auth(self): + @parameterized.expand(["GET", "POST", "DELETE"]) + def test_no_auth(self, method: str): """ Try to get information of a user without authentication. """ - channel = self.make_request("GET", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("POST", self.url, b"{}") + channel = self.make_request(method, self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("DELETE", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + @parameterized.expand(["GET", "POST", "DELETE"]) + def test_requester_is_no_admin(self, method: str): """ If the user is not a server admin, an error is returned. """ other_user_token = self.login("user", "pass") channel = self.make_request( - "GET", - self.url, - access_token=other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "POST", - self.url, - access_token=other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", + method, self.url, access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + @parameterized.expand(["GET", "POST", "DELETE"]) + def test_user_does_not_exist(self, method: str): """ Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - channel = self.make_request( - "POST", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", + method, url, access_token=self.admin_user_tok, ) @@ -3532,7 +3438,14 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): + @parameterized.expand( + [ + ("GET", "Can only look up local users"), + ("POST", "Only local users can be ratelimited"), + ("DELETE", "Only local users can be ratelimited"), + ] + ) + def test_user_is_not_local(self, method: str, error_msg: str): """ Tests that a lookup for a user that is not a local returns a 400 """ @@ -3541,35 +3454,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only look up local users", channel.json_body["error"]) - - channel = self.make_request( - "POST", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual( - "Only local users can be ratelimited", channel.json_body["error"] - ) - - channel = self.make_request( - "DELETE", + method, url, access_token=self.admin_user_tok, ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual( - "Only local users can be ratelimited", channel.json_body["error"] - ) + self.assertEqual(error_msg, channel.json_body["error"]) def test_invalid_parameter(self): """ |