summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_admin.py31
-rw-r--r--tests/rest/admin/test_background_updates.py29
-rw-r--r--tests/rest/admin/test_device.py77
-rw-r--r--tests/rest/admin/test_event_reports.py142
-rw-r--r--tests/rest/admin/test_federation.py73
-rw-r--r--tests/rest/admin/test_media.py166
-rw-r--r--tests/rest/admin/test_registration_tokens.py201
-rw-r--r--tests/rest/admin/test_room.py266
-rw-r--r--tests/rest/admin/test_server_notice.py93
-rw-r--r--tests/rest/admin/test_statistics.py99
-rw-r--r--tests/rest/admin/test_user.py562
-rw-r--r--tests/rest/admin/test_username_available.py17
-rw-r--r--tests/rest/client/test_account.py109
-rw-r--r--tests/rest/client/test_directory.py59
-rw-r--r--tests/rest/client/test_filter.py14
-rw-r--r--tests/rest/client/test_identity.py7
-rw-r--r--tests/rest/client/test_login.py119
-rw-r--r--tests/rest/client/test_models.py53
-rw-r--r--tests/rest/client/test_password_policy.py31
-rw-r--r--tests/rest/client/test_profile.py18
-rw-r--r--tests/rest/client/test_redactions.py4
-rw-r--r--tests/rest/client/test_register.py199
-rw-r--r--tests/rest/client/test_relations.py21
-rw-r--r--tests/rest/client/test_report_event.py11
-rw-r--r--tests/rest/client/test_retention.py4
-rw-r--r--tests/rest/client/test_rooms.py936
-rw-r--r--tests/rest/client/test_shadow_banned.py13
-rw-r--r--tests/rest/client/test_sync.py59
-rw-r--r--tests/rest/client/test_third_party_rules.py38
-rw-r--r--tests/rest/client/test_upgrade_room.py83
-rw-r--r--tests/rest/client/utils.py57
-rw-r--r--tests/rest/media/v1/test_html_preview.py58
-rw-r--r--tests/rest/media/v1/test_media_storage.py160
-rw-r--r--tests/rest/test_health.py4
-rw-r--r--tests/rest/test_well_known.py32
35 files changed, 2404 insertions, 1441 deletions
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py

index 82ac5991e6..a8f6436836 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py
@@ -13,7 +13,6 @@ # limitations under the License. import urllib.parse -from http import HTTPStatus from parameterized import parameterized @@ -42,7 +41,7 @@ class VersionTestCase(unittest.HomeserverTestCase): def test_version_string(self) -> None: channel = self.make_request("GET", self.url, shorthand=False) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( {"server_version", "python_version"}, set(channel.json_body.keys()) ) @@ -79,10 +78,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Should be quarantined self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing quarantined media: %s" + "Expected to receive a 404 on accessing quarantined media: %s" % server_and_media_id ), ) @@ -107,7 +106,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Expect a forbidden error self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg="Expected forbidden on quarantining media as a non-admin", ) @@ -139,7 +138,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ) # Should be successful - self.assertEqual(HTTPStatus.OK, channel.code) + self.assertEqual(200, channel.code) # Quarantine the media url = "/_synapse/admin/v1/media/quarantine/%s/%s" % ( @@ -152,7 +151,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Attempt to access the media self._ensure_quarantined(admin_user_tok, server_name_and_media_id) @@ -209,7 +208,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" ) @@ -251,7 +250,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" ) @@ -285,7 +284,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) channel = self.make_request("POST", url, access_token=admin_user_tok) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Quarantine all media by this user url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( @@ -297,7 +296,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item" ) @@ -318,10 +317,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Shouldn't be quarantined self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - "Expected to receive a HTTPStatus.OK on accessing not-quarantined media: %s" + "Expected to receive a 200 on accessing not-quarantined media: %s" % server_and_media_id_2 ), ) @@ -350,7 +349,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): def test_purge_history(self) -> None: """ Simple test of purge history API. - Test only that is is possible to call, get status HTTPStatus.OK and purge_id. + Test only that is is possible to call, get status 200 and purge_id. """ channel = self.make_request( @@ -360,7 +359,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("purge_id", channel.json_body) purge_id = channel.json_body["purge_id"] @@ -371,5 +370,5 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("complete", channel.json_body["status"]) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 6cf56b1e35..d507a3af8d 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py
@@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import Collection from parameterized import parameterized @@ -51,7 +50,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ) def test_requester_is_no_admin(self, method: str, url: str) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ self.register_user("user", "pass", admin=False) @@ -64,7 +63,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): access_token=other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -81,7 +80,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # job_name invalid @@ -92,7 +91,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) def _register_bg_update(self) -> None: @@ -125,7 +124,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Background updates should be enabled, but none should be running. self.assertDictEqual( @@ -147,7 +146,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Background updates should be enabled, and one should be running. self.assertDictEqual( @@ -181,7 +180,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/enabled", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": True}) # Disable the BG updates @@ -191,7 +190,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): content={"enabled": False}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": False}) # Advance a bit and get the current status, note this will finish the in @@ -204,7 +203,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertDictEqual( channel.json_body, { @@ -231,7 +230,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # There should be no change from the previous /status response. self.assertDictEqual( @@ -259,7 +258,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): content={"enabled": True}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": True}) @@ -270,7 +269,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Background updates should be enabled and making progress. self.assertDictEqual( @@ -325,7 +324,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # test that each background update is waiting now for update in updates: @@ -365,4 +364,4 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index f7080bda87..d52aee8f92 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py
@@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import urllib.parse -from http import HTTPStatus from parameterized import parameterized @@ -58,7 +57,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): channel = self.make_request(method, self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -76,7 +75,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -85,7 +84,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): @parameterized.expand(["GET", "PUT", "DELETE"]) def test_user_does_not_exist(self, method: str) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = ( "/_synapse/admin/v2/users/@unknown_person:test/devices/%s" @@ -98,13 +97,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) def test_user_is_not_local(self, method: str) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = ( "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s" @@ -117,12 +116,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) def test_unknown_device(self) -> None: """ - Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK. + Tests that a lookup for a device that does not exist returns either 404 or 200. """ url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote( self.other_user @@ -134,7 +133,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) channel = self.make_request( @@ -143,7 +142,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) channel = self.make_request( "DELETE", @@ -151,8 +150,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - # Delete unknown device returns status HTTPStatus.OK - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + # Delete unknown device returns status 200 + self.assertEqual(200, channel.code, msg=channel.json_body) def test_update_device_too_long_display_name(self) -> None: """ @@ -179,7 +178,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): content=update, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. @@ -189,12 +188,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) def test_update_no_display_name(self) -> None: """ - Tests that a update for a device without JSON returns a HTTPStatus.OK + Tests that a update for a device without JSON returns a 200 """ # Set iniital display name. update = {"display_name": "new display"} @@ -210,7 +209,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure the display name was not updated. channel = self.make_request( @@ -219,7 +218,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) def test_update_display_name(self) -> None: @@ -234,7 +233,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): content={"display_name": "new displayname"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check new display_name channel = self.make_request( @@ -243,7 +242,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) def test_get_device(self) -> None: @@ -256,7 +255,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) # Check that all fields are available self.assertIn("user_id", channel.json_body) @@ -281,7 +280,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure that the number of devices is decreased res = self.get_success(self.handler.get_devices_by_user(self.other_user)) @@ -312,7 +311,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -331,7 +330,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -339,7 +338,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): def test_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/devices" channel = self.make_request( @@ -348,12 +347,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" @@ -363,7 +362,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) def test_user_has_no_devices(self) -> None: @@ -379,7 +378,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["devices"])) @@ -399,7 +398,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_devices, channel.json_body["total"]) self.assertEqual(number_devices, len(channel.json_body["devices"])) self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) @@ -438,7 +437,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -457,7 +456,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -465,7 +464,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): def test_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" channel = self.make_request( @@ -474,12 +473,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" @@ -489,12 +488,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) def test_unknown_devices(self) -> None: """ - Tests that a remove of a device that does not exist returns HTTPStatus.OK. + Tests that a remove of a device that does not exist returns 200. """ channel = self.make_request( "POST", @@ -503,8 +502,8 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): content={"devices": ["unknown_device1", "unknown_device2"]}, ) - # Delete unknown devices returns status HTTPStatus.OK - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + # Delete unknown devices returns status 200 + self.assertEqual(200, channel.code, msg=channel.json_body) def test_delete_devices(self) -> None: """ @@ -533,7 +532,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): content={"devices": device_ids}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) res = self.get_success(self.handler.get_devices_by_user(self.other_user)) self.assertEqual(0, len(res)) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 4f89f8b534..8a4e5c3f77 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py
@@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import List from twisted.test.proto_helpers import MemoryReactor @@ -81,16 +80,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -99,11 +94,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_default_success(self) -> None: @@ -117,7 +108,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -134,7 +125,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -151,7 +142,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -168,7 +159,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["event_reports"]), 10) @@ -185,7 +176,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -205,7 +196,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -225,7 +216,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertNotIn("next_token", channel.json_body) @@ -247,7 +238,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -265,7 +256,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -278,7 +269,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): def test_invalid_search_order(self) -> None: """ - Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST + Testing that a invalid search order returns a 400 """ channel = self.make_request( @@ -287,17 +278,13 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual("Unknown direction: bar", channel.json_body["error"]) def test_limit_is_negative(self) -> None: """ - Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST + Testing that a negative limit parameter returns a 400 """ channel = self.make_request( @@ -306,16 +293,12 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_from_is_negative(self) -> None: """ - Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST + Testing that a negative from parameter returns a 400 """ channel = self.make_request( @@ -324,11 +307,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self) -> None: @@ -344,7 +323,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -357,7 +336,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -370,7 +349,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -384,7 +363,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -400,7 +379,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): {"score": -100, "reason": "this makes me sad"}, access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def _create_event_and_report_without_parameters( self, room_id: str, user_tok: str @@ -415,7 +394,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): {}, access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def _check_fields(self, content: List[JsonDict]) -> None: """Checks that all attributes are present in an event report""" @@ -431,6 +410,33 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertIn("score", c) self.assertIn("reason", c) + def test_count_correct_despite_table_deletions(self) -> None: + """ + Tests that the count matches the number of rows, even if rows in joined tables + are missing. + """ + + # Delete rows from room_stats_state for one of our rooms. + self.get_success( + self.hs.get_datastores().main.db_pool.simple_delete( + "room_stats_state", {"room_id": self.room_id1}, desc="_" + ) + ) + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + # The 'total' field is 10 because only 10 reports will actually + # be retrievable since we deleted the rows in the room_stats_state + # table. + self.assertEqual(channel.json_body["total"], 10) + # This is consistent with the number of rows actually returned. + self.assertEqual(len(channel.json_body["event_reports"]), 10) + class EventReportDetailTestCase(unittest.HomeserverTestCase): servlets = [ @@ -466,16 +472,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -484,11 +486,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_default_success(self) -> None: @@ -502,12 +500,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) def test_invalid_report_id(self) -> None: """ - Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST. + Testing that an invalid `report_id` returns a 400. """ # `report_id` is negative @@ -517,11 +515,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -535,11 +529,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -553,11 +543,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -566,7 +552,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): def test_report_id_not_found(self) -> None: """ - Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND. + Testing that a not existing `report_id` returns a 404. """ channel = self.make_request( @@ -575,11 +561,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual("Event report not found", channel.json_body["error"]) @@ -594,7 +576,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): {"score": -100, "reason": "this makes me sad"}, access_token=user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def _check_fields(self, content: JsonDict) -> None: """Checks that all attributes are present in a event report""" diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 929bbdc37d..4c7864c629 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py
@@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import List, Optional from parameterized import parameterized @@ -64,7 +63,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -77,7 +76,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -87,7 +86,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # unkown order_by @@ -97,7 +96,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -107,7 +106,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid destination @@ -117,7 +116,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) # invalid destination @@ -127,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_limit(self) -> None: @@ -142,7 +141,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), 5) self.assertEqual(channel.json_body["next_token"], "5") @@ -160,7 +159,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -178,7 +177,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(len(channel.json_body["destinations"]), 10) @@ -198,7 +197,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), number_destinations) self.assertNotIn("next_token", channel.json_body) @@ -211,7 +210,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), number_destinations) self.assertNotIn("next_token", channel.json_body) @@ -224,7 +223,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), 19) self.assertEqual(channel.json_body["next_token"], "19") @@ -238,7 +237,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_destinations) self.assertEqual(len(channel.json_body["destinations"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -255,7 +254,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_destinations, len(channel.json_body["destinations"])) self.assertEqual(number_destinations, channel.json_body["total"]) @@ -290,7 +289,7 @@ class FederationTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_destination_list)) returned_order = [ @@ -376,7 +375,7 @@ class FederationTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that destinations were returned self.assertTrue("destinations" in channel.json_body) @@ -418,7 +417,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("sub0.example.com", channel.json_body["destination"]) # Check that all fields are available @@ -435,7 +434,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("sub0.example.com", channel.json_body["destination"]) self.assertEqual(0, channel.json_body["retry_last_ts"]) self.assertEqual(0, channel.json_body["retry_interval"]) @@ -452,7 +451,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) retry_timings = self.get_success( self.store.get_destination_retry_timings("sub0.example.com") @@ -469,7 +468,7 @@ class FederationTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "The retry timing does not need to be reset for this destination.", channel.json_body["error"], @@ -561,7 +560,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -574,7 +573,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -584,7 +583,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -594,7 +593,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid destination @@ -604,7 +603,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_limit(self) -> None: @@ -619,7 +618,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), 3) self.assertEqual(channel.json_body["next_token"], "3") @@ -637,7 +636,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), 5) self.assertNotIn("next_token", channel.json_body) @@ -655,7 +654,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(channel.json_body["next_token"], "8") self.assertEqual(len(channel.json_body["rooms"]), 5) @@ -673,7 +672,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel_asc.code, msg=channel_asc.json_body) + self.assertEqual(200, channel_asc.code, msg=channel_asc.json_body) self.assertEqual(channel_asc.json_body["total"], number_rooms) self.assertEqual(number_rooms, len(channel_asc.json_body["rooms"])) self._check_fields(channel_asc.json_body["rooms"]) @@ -685,7 +684,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel_desc.code, msg=channel_desc.json_body) + self.assertEqual(200, channel_desc.code, msg=channel_desc.json_body) self.assertEqual(channel_desc.json_body["total"], number_rooms) self.assertEqual(number_rooms, len(channel_desc.json_body["rooms"])) self._check_fields(channel_desc.json_body["rooms"]) @@ -711,7 +710,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), number_rooms) self.assertNotIn("next_token", channel.json_body) @@ -724,7 +723,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), number_rooms) self.assertNotIn("next_token", channel.json_body) @@ -737,7 +736,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), 4) self.assertEqual(channel.json_body["next_token"], "4") @@ -751,7 +750,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(len(channel.json_body["rooms"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -767,7 +766,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_rooms) self.assertEqual(number_rooms, len(channel.json_body["rooms"])) self._check_fields(channel.json_body["rooms"]) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index e909e444ac..aadb31ca83 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py
@@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from http import HTTPStatus from parameterized import parameterized @@ -60,7 +59,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): channel = self.make_request("DELETE", url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -81,16 +80,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_media_does_not_exist(self) -> None: """ - Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a media that does not exist returns a 404 """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") @@ -100,12 +95,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_media_is_not_local(self) -> None: """ - Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a media that is not a local returns a 400 """ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") @@ -115,7 +110,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) def test_delete_media(self) -> None: @@ -131,7 +126,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): upload_resource, SMALL_PNG, tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + expect_code=200, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -151,11 +146,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Should be successful self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - "Expected to receive a HTTPStatus.OK on accessing media: %s" - % server_and_media_id + "Expected to receive a 200 on accessing media: %s" % server_and_media_id ), ) @@ -172,7 +166,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -189,10 +183,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" + "Expected to receive a 404 on accessing deleted media: %s" % server_and_media_id ), ) @@ -231,11 +225,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -251,16 +241,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_media_is_not_local(self) -> None: """ - Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for media that is not local returns a 400 """ url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" @@ -270,7 +256,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) def test_missing_parameter(self) -> None: @@ -283,11 +269,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual( "Missing integer query parameter 'before_ts'", channel.json_body["error"] @@ -303,11 +285,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -320,11 +298,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " @@ -338,11 +312,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter size_gt must be a string representing a positive integer.", @@ -355,11 +325,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", @@ -388,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -413,7 +379,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -425,7 +391,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -449,7 +415,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -460,7 +426,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -485,7 +451,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): content={"avatar_url": "mxc://%s" % (server_and_media_id,)}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -493,7 +459,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -504,7 +470,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -530,7 +496,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): content={"url": "mxc://%s" % (server_and_media_id,)}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -538,7 +504,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -549,7 +515,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -569,7 +535,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): upload_resource, SMALL_PNG, tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + expect_code=200, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -602,10 +568,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): if expect_success: self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - "Expected to receive a HTTPStatus.OK on accessing media: %s" + "Expected to receive a 200 on accessing media: %s" % server_and_media_id ), ) @@ -613,10 +579,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertTrue(os.path.exists(local_path)) else: self.assertEqual( - HTTPStatus.NOT_FOUND, + 404, channel.code, msg=( - "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" + "Expected to receive a 404 on accessing deleted media: %s" % (server_and_media_id) ), ) @@ -648,7 +614,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): upload_resource, SMALL_PNG, tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + expect_code=200, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -668,11 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): b"{}", ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["quarantine", "unquarantine"]) @@ -689,11 +651,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_quarantine_media(self) -> None: @@ -712,7 +670,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -726,7 +684,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -753,7 +711,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) # verify that is not in quarantine @@ -785,7 +743,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): upload_resource, SMALL_PNG, tok=self.admin_user_tok, - expect_code=HTTPStatus.OK, + expect_code=200, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -801,11 +759,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url % (action, self.media_id), b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["protect", "unprotect"]) @@ -822,11 +776,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_protect_media(self) -> None: @@ -845,7 +795,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -859,7 +809,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -895,7 +845,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -914,11 +864,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -931,11 +877,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -948,11 +890,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 8354250ec2..8f8abc21c7 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py
@@ -13,7 +13,6 @@ # limitations under the License. import random import string -from http import HTTPStatus from typing import Optional from twisted.test.proto_helpers import MemoryReactor @@ -74,11 +73,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): def test_create_no_auth(self) -> None: """Try to create a token without authentication.""" channel = self.make_request("POST", self.url + "/new", {}) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_create_requester_not_admin(self) -> None: @@ -89,11 +84,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_create_using_defaults(self) -> None: @@ -105,7 +96,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -129,7 +120,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) @@ -150,7 +141,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -168,11 +159,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_token_invalid_chars(self) -> None: @@ -188,11 +175,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_token_already_exists(self) -> None: @@ -207,7 +190,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body) + self.assertEqual(200, channel1.code, msg=channel1.json_body) channel2 = self.make_request( "POST", @@ -215,7 +198,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body) + self.assertEqual(400, channel2.code, msg=channel2.json_body) self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) def test_create_unable_to_generate_token(self) -> None: @@ -251,7 +234,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 0) # Should fail with negative integer @@ -262,7 +245,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) self.assertEqual( - HTTPStatus.BAD_REQUEST, + 400, channel.code, msg=channel.json_body, ) @@ -275,11 +258,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_expiry_time(self) -> None: @@ -291,11 +270,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() - 10000}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -305,11 +280,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() + 1000000.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_length(self) -> None: @@ -321,7 +292,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 64}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 64) # Should fail with 0 @@ -331,11 +302,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 0}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -345,11 +312,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a float @@ -359,11 +322,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 8.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with 65 @@ -373,11 +332,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 65}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # UPDATING @@ -389,11 +344,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_update_requester_not_admin(self) -> None: @@ -404,11 +355,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_update_non_existent(self) -> None: @@ -420,11 +367,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_update_uses_allowed(self) -> None: @@ -439,7 +382,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertIsNone(channel.json_body["expiry_time"]) @@ -450,7 +393,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 0) self.assertIsNone(channel.json_body["expiry_time"]) @@ -461,7 +404,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": None}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -472,11 +415,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -486,11 +425,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_update_expiry_time(self) -> None: @@ -506,7 +441,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -517,7 +452,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": None}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -529,11 +464,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": past_time}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail a float @@ -543,11 +474,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time + 0.5}, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_update_both(self) -> None: @@ -568,7 +495,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) @@ -589,11 +516,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # DELETING @@ -605,11 +528,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_delete_requester_not_admin(self) -> None: @@ -620,11 +539,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_delete_non_existent(self) -> None: @@ -636,11 +551,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_delete(self) -> None: @@ -655,7 +566,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # GETTING ONE @@ -666,11 +577,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_get_requester_not_admin(self) -> None: @@ -682,7 +589,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -697,11 +604,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.NOT_FOUND, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_get(self) -> None: @@ -716,7 +619,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["token"], token) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -728,11 +631,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): def test_list_no_auth(self) -> None: """Try to list tokens without authentication.""" channel = self.make_request("GET", self.url, {}) - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_list_requester_not_admin(self) -> None: @@ -743,11 +642,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_list_all(self) -> None: @@ -762,7 +657,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["registration_tokens"]), 1) token_info = channel.json_body["registration_tokens"][0] self.assertEqual(token_info["token"], token) @@ -780,11 +675,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) def _test_list_query_parameter(self, valid: str) -> None: """Helper used to test both valid=true and valid=false.""" @@ -816,7 +707,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["registration_tokens"]), 2) token_info_1 = channel.json_body["registration_tokens"][0] token_info_2 = channel.json_body["registration_tokens"][1] diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ca6af9417b..9d71a97524 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py
@@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import urllib.parse -from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock @@ -21,7 +20,7 @@ from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RoomTypes from synapse.api.errors import Codes from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room @@ -68,7 +67,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -78,7 +77,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self) -> None: @@ -94,11 +93,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_room_is_not_valid(self) -> None: """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. """ url = "/_synapse/admin/v1/rooms/%s" % "invalidroom" @@ -109,7 +108,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -127,7 +126,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("new_room_id", channel.json_body) self.assertIn("kicked_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -145,7 +144,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -163,7 +162,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self) -> None: @@ -178,7 +177,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_room_and_block(self) -> None: @@ -202,7 +201,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -233,7 +232,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -265,7 +264,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -296,7 +295,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) # The room is now blocked. - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self._is_blocked(room_id) def test_shutdown_room_consent(self) -> None: @@ -319,7 +318,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self.room_id, body="foo", tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + expect_code=403, ) # Test that room is not purged @@ -337,7 +336,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -366,7 +365,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): {"history_visibility": "world_readable"}, access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -383,7 +382,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -398,7 +397,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._has_no_members(self.room_id) # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) + self._assert_peek(self.room_id, expect_code=403) def _is_blocked(self, room_id: str, expect: bool = True) -> None: """Assert that the room is blocked or not""" @@ -494,7 +493,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): ) def test_requester_is_no_admin(self, method: str, url: str) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -504,7 +503,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self) -> None: @@ -522,7 +521,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -533,7 +532,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, len(channel.json_body["results"])) self.assertEqual("complete", channel.json_body["results"][0]["status"]) self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"]) @@ -546,7 +545,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): ) def test_room_is_not_valid(self, method: str, url: str) -> None: """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. """ channel = self.make_request( @@ -556,7 +555,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -574,7 +573,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -592,7 +591,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -610,7 +609,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self) -> None: @@ -625,7 +624,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_delete_expired_status(self) -> None: @@ -639,7 +638,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id1 = channel.json_body["delete_id"] @@ -654,7 +653,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id2 = channel.json_body["delete_id"] @@ -665,7 +664,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(2, len(channel.json_body["results"])) self.assertEqual("complete", channel.json_body["results"][0]["status"]) self.assertEqual("complete", channel.json_body["results"][1]["status"]) @@ -682,7 +681,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, len(channel.json_body["results"])) self.assertEqual("complete", channel.json_body["results"][0]["status"]) self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"]) @@ -696,7 +695,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_delete_same_room_twice(self) -> None: @@ -722,9 +721,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body - ) + self.assertEqual(400, second_channel.code, msg=second_channel.json_body) self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"]) self.assertEqual( f"History purge already in progress for {self.room_id}", @@ -733,7 +730,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): # get result of first call first_channel.await_result() - self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body) + self.assertEqual(200, first_channel.code, msg=first_channel.json_body) self.assertIn("delete_id", first_channel.json_body) # check status after finish the task @@ -764,7 +761,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -795,7 +792,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -827,7 +824,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -858,7 +855,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.room_id, body="foo", tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + expect_code=403, ) # Test that room is not purged @@ -876,7 +873,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -887,7 +884,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.url_status_by_room_id, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, len(channel.json_body["results"])) # Test that member has moved to new room @@ -914,7 +911,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): content={"history_visibility": "world_readable"}, access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -931,7 +928,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("delete_id", channel.json_body) delete_id = channel.json_body["delete_id"] @@ -942,7 +939,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.url_status_by_room_id, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, len(channel.json_body["results"])) # Test that member has moved to new room @@ -955,7 +952,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self._has_no_members(self.room_id) # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) + self._assert_peek(self.room_id, expect_code=403) def _is_blocked(self, room_id: str, expect: bool = True) -> None: """Assert that the room is blocked or not""" @@ -1026,9 +1023,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.url_status_by_room_id, access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body - ) + self.assertEqual(200, channel_room_id.code, msg=channel_room_id.json_body) self.assertEqual(1, len(channel_room_id.json_body["results"])) self.assertEqual( delete_id, channel_room_id.json_body["results"][0]["delete_id"] @@ -1041,7 +1036,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) self.assertEqual( - HTTPStatus.OK, + 200, channel_delete_id.code, msg=channel_delete_id.json_body, ) @@ -1085,7 +1080,9 @@ class RoomTestCase(unittest.HomeserverTestCase): room_ids = [] for _ in range(total_rooms): room_id = self.helper.create_room_as( - self.admin_user, tok=self.admin_user_tok + self.admin_user, + tok=self.admin_user_tok, + is_public=True, ) room_ids.append(room_id) @@ -1100,7 +1097,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) # Check request completed successfully - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that response json body contains a "rooms" key self.assertTrue( @@ -1124,12 +1121,14 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("version", r) self.assertIn("creator", r) self.assertIn("encryption", r) - self.assertIn("federatable", r) - self.assertIn("public", r) + self.assertIs(r["federatable"], True) + self.assertIs(r["public"], True) self.assertIn("join_rules", r) self.assertIn("guest_access", r) self.assertIn("history_visibility", r) self.assertIn("state_events", r) + self.assertIn("room_type", r) + self.assertIsNone(r["room_type"]) # Check that the correct number of total rooms was returned self.assertEqual(channel.json_body["total_rooms"], total_rooms) @@ -1184,7 +1183,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue("rooms" in channel.json_body) for r in channel.json_body["rooms"]: @@ -1224,12 +1223,16 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_correct_room_attributes(self) -> None: """Test the correct attributes for a room are returned""" # Create a test room - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id = self.helper.create_room_as( + self.admin_user, + tok=self.admin_user_tok, + extra_content={"creation_content": {"type": RoomTypes.SPACE}}, + ) test_alias = "#test:test" test_room_name = "something" @@ -1247,7 +1250,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1279,7 +1282,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1306,6 +1309,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id, r["room_id"]) self.assertEqual(test_room_name, r["name"]) self.assertEqual(test_alias, r["canonical_alias"]) + self.assertEqual(RoomTypes.SPACE, r["room_type"]) def test_room_list_sort_order(self) -> None: """Test room list sort ordering. alphabetical name versus number of members, @@ -1334,7 +1338,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1480,7 +1484,7 @@ class RoomTestCase(unittest.HomeserverTestCase): def _search_test( expected_room_id: Optional[str], search_term: str, - expected_http_code: int = HTTPStatus.OK, + expected_http_code: int = 200, ) -> None: """Search for a room and check that the returned room's id is a match @@ -1498,7 +1502,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if expected_http_code != HTTPStatus.OK: + if expected_http_code != 200: return # Check that rooms were returned @@ -1541,7 +1545,7 @@ class RoomTestCase(unittest.HomeserverTestCase): _search_test(None, "foo") _search_test(None, "bar") - _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST) + _search_test(None, "", expected_http_code=400) # Test that the whole room id returns the room _search_test(room_id_1, room_id_1) @@ -1578,15 +1582,19 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) - self.assertEqual("ะถ", channel.json_body.get("rooms")[0].get("name")) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id")) + self.assertEqual("ะถ", channel.json_body["rooms"][0].get("name")) def test_single_room(self) -> None: """Test that a single room can be requested correctly""" # Create two test rooms - room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_1 = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=True + ) + room_id_2 = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok, is_public=False + ) room_name_1 = "something" room_name_2 = "else" @@ -1611,7 +1619,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("room_id", channel.json_body) self.assertIn("name", channel.json_body) @@ -1630,8 +1638,12 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("guest_access", channel.json_body) self.assertIn("history_visibility", channel.json_body) self.assertIn("state_events", channel.json_body) + self.assertIn("room_type", channel.json_body) + self.assertIn("forgotten", channel.json_body) self.assertEqual(room_id_1, channel.json_body["room_id"]) + self.assertIs(True, channel.json_body["federatable"]) + self.assertIs(True, channel.json_body["public"]) def test_single_room_devices(self) -> None: """Test that `joined_local_devices` can be requested correctly""" @@ -1643,7 +1655,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["joined_local_devices"]) # Have another user join the room @@ -1657,7 +1669,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(2, channel.json_body["joined_local_devices"]) # leave room @@ -1669,7 +1681,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["joined_local_devices"]) def test_room_members(self) -> None: @@ -1700,7 +1712,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] @@ -1713,7 +1725,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] @@ -1731,7 +1743,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("state", channel.json_body) # testing that the state events match is painful and not done here. We assume that # the create_room already does the right thing, so no need to verify that we got @@ -1748,7 +1760,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1765,6 +1777,21 @@ class RoomTestCase(unittest.HomeserverTestCase): tok=admin_user_tok, ) + def test_get_joined_members_after_leave_room(self) -> None: + """Test that requesting room members after leaving the room raises a 403 error.""" + + # create the room + user = self.register_user("foo", "pass") + user_tok = self.login("foo", "pass") + room_id = self.helper.create_room_as(user, tok=user_tok) + self.helper.leave(room_id, user, tok=user_tok) + + # delete the rooms and get joined roomed membership + url = f"/_matrix/client/r0/rooms/{room_id}/joined_members" + channel = self.make_request("GET", url.encode("ascii"), access_token=user_tok) + self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): @@ -1791,7 +1818,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( @@ -1801,7 +1828,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.second_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -1816,12 +1843,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_local_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ channel = self.make_request( @@ -1831,7 +1858,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_remote_user(self) -> None: @@ -1846,7 +1873,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "This endpoint can only be used with local users", channel.json_body["error"], @@ -1854,7 +1881,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): def test_room_does_not_exist(self) -> None: """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + Check that unknown rooms/server return error 404. """ url = "/_synapse/admin/v1/join/!unknown:test" @@ -1865,12 +1892,15 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) - self.assertEqual("No known servers", channel.json_body["error"]) + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual( + "Can't join remote room because no servers that are in the room have been provided.", + channel.json_body["error"], + ) def test_room_is_not_valid(self) -> None: """ - Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. + Check that invalid room names, return an error 400. """ url = "/_synapse/admin/v1/join/invalidroom" @@ -1881,7 +1911,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom was not legal room ID or room alias", channel.json_body["error"], @@ -1899,7 +1929,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1909,7 +1939,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_not_member(self) -> None: @@ -1929,7 +1959,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_join_private_room_if_member(self) -> None: @@ -1957,7 +1987,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. @@ -1970,7 +2000,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1980,7 +2010,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_owner(self) -> None: @@ -2000,7 +2030,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -2010,7 +2040,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_context_as_non_admin(self) -> None: @@ -2044,7 +2074,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self) -> None: @@ -2074,7 +2104,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) @@ -2133,7 +2163,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) @@ -2160,7 +2190,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room (we should have received an # invite) and can ban a user. @@ -2186,7 +2216,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.second_user_id, tok=self.second_tok) @@ -2220,11 +2250,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - # We expect this to fail with a HTTPStatus.BAD_REQUEST as there are no room admins. + # We expect this to fail with a 400 as there are no room admins. # # (Note we assert the error message to ensure that it's not denied for # some other reason) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body["error"], "No local admin user in room with power to update power levels.", @@ -2254,7 +2284,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): @parameterized.expand([("PUT",), ("GET",)]) def test_requester_is_no_admin(self, method: str) -> None: - """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.""" + """If the user is not a server admin, an error 403 is returned.""" channel = self.make_request( method, @@ -2263,12 +2293,12 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand([("PUT",), ("GET",)]) def test_room_is_not_valid(self, method: str) -> None: - """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.""" + """Check that invalid room names, return an error 400.""" channel = self.make_request( method, @@ -2277,7 +2307,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -2294,7 +2324,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) # `block` is not set @@ -2305,7 +2335,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # no content is send @@ -2315,7 +2345,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) def test_block_room(self) -> None: @@ -2329,7 +2359,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): content={"block": True}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["block"]) self._is_blocked(room_id, expect=True) @@ -2353,7 +2383,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): content={"block": True}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["block"]) self._is_blocked(self.room_id, expect=True) @@ -2369,7 +2399,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): content={"block": False}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body["block"]) self._is_blocked(room_id, expect=False) @@ -2393,7 +2423,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): content={"block": False}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body["block"]) self._is_blocked(self.room_id, expect=False) @@ -2408,7 +2438,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.url % room_id, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["block"]) self.assertEqual(self.other_user, channel.json_body["user_id"]) @@ -2432,7 +2462,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.url % room_id, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body["block"]) self.assertNotIn("user_id", channel.json_body) diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index dbcba2663c..a2f347f666 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import List from twisted.test.proto_helpers import MemoryReactor @@ -57,7 +56,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url) self.assertEqual( - HTTPStatus.UNAUTHORIZED, + 401, channel.code, msg=channel.json_body, ) @@ -72,7 +71,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.FORBIDDEN, + 403, channel.code, msg=channel.json_body, ) @@ -80,7 +79,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) def test_user_does_not_exist(self) -> None: - """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" + """Tests that a lookup for a user that does not exist returns a 404""" channel = self.make_request( "POST", self.url, @@ -88,13 +87,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": "@unknown_person:test", "content": ""}, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ channel = self.make_request( "POST", @@ -106,7 +105,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "Server notices can only be sent to local users", channel.json_body["error"] ) @@ -122,7 +121,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) # no content @@ -133,7 +132,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # no body @@ -144,7 +143,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user, "content": ""}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'body' not in content", channel.json_body["error"]) @@ -156,10 +155,66 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user, "content": {"body": ""}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'msgtype' not in content", channel.json_body["error"]) + @override_config( + { + "server_notices": { + "system_mxid_localpart": "notices", + "system_mxid_avatar_url": "somthingwrong", + }, + "max_avatar_size": "10M", + } + ) + def test_invalid_avatar_url(self) -> None: + """If avatar url in homeserver.yaml is invalid and + "check avatar size and mime type" is set, an error is returned. + TODO: Should be checked when reading the configuration.""" + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg"}, + }, + ) + + self.assertEqual(500, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + @override_config( + { + "server_notices": { + "system_mxid_localpart": "notices", + "system_mxid_display_name": "test display name", + "system_mxid_avatar_url": None, + }, + "max_avatar_size": "10M", + } + ) + def test_displayname_is_set_avatar_is_none(self) -> None: + """ + Tests that sending a server notices is successfully, + if a display_name is set, avatar_url is `None` and + "check avatar size and mime type" is set. + """ + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + self._check_invite_and_join_status(self.other_user, 1, 0) + def test_server_notice_disabled(self) -> None: """Tests that server returns error if server notice is disabled""" channel = self.make_request( @@ -172,7 +227,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( "Server notices are not enabled on this server", channel.json_body["error"] @@ -197,7 +252,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -226,7 +281,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has no new invites or memberships self._check_invite_and_join_status(self.other_user, 0, 1) @@ -260,7 +315,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -301,7 +356,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -341,7 +396,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -388,7 +443,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -538,7 +593,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=token ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, 200) # Get the messages room = channel.json_body["rooms"]["join"][room_id] diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 7cb8ec57ba..b60f16b914 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py
@@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor @@ -51,16 +50,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual( - HTTPStatus.UNAUTHORIZED, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: """ - If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. + If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( "GET", @@ -69,11 +64,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual( - HTTPStatus.FORBIDDEN, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -87,11 +78,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -101,11 +88,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit @@ -115,11 +98,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from_ts @@ -129,11 +108,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative until_ts @@ -143,11 +118,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # until_ts smaller from_ts @@ -157,11 +128,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # empty search term @@ -171,11 +138,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -185,11 +148,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_limit(self) -> None: @@ -204,7 +163,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -222,7 +181,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -240,7 +199,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["users"]), 10) @@ -262,7 +221,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -275,7 +234,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -288,7 +247,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -301,7 +260,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -318,7 +277,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["users"])) @@ -415,7 +374,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media starting at `ts1` after creating first media @@ -425,7 +384,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) self._create_media(self.other_user_tok, 3) @@ -440,7 +399,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media until `ts2` and earlier @@ -449,7 +408,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) def test_search_term(self) -> None: @@ -461,7 +420,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) # filter user 1 and 10-19 by `user_id` @@ -470,7 +429,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foo_user_1", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 11) # filter on this user in `displayname` @@ -479,7 +438,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=bar_user_10", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10") self.assertEqual(channel.json_body["total"], 1) @@ -489,7 +448,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foobar", access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) def _create_users_with_media(self, number_users: int, media_per_user: int) -> None: @@ -515,7 +474,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): for _ in range(number_media): # Upload some media into the room self.helper.upload_media( - upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK + upload_resource, SMALL_PNG, tok=user_token, expect_code=200 ) def _check_fields(self, content: List[JsonDict]) -> None: @@ -549,7 +508,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["user_id"] for row in channel.json_body["users"]] diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 0d44102237..1afd082707 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -1,4 +1,4 @@ -# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# Copyright 2018-2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ import hmac import os import urllib.parse from binascii import unhexlify -from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock, patch @@ -79,7 +78,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "Shared secret registration is not enabled", channel.json_body["error"] ) @@ -111,7 +110,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = {"nonce": nonce} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds @@ -119,7 +118,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self) -> None: @@ -142,7 +141,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("HMAC incorrect", channel.json_body["error"]) def test_register_correct_nonce(self) -> None: @@ -169,7 +168,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) def test_nonce_reuse(self) -> None: @@ -192,13 +191,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self) -> None: @@ -219,7 +218,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be an empty body present channel = self.make_request("POST", self.url, {}) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) # @@ -229,28 +228,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present channel = self.make_request("POST", self.url, {"nonce": nonce()}) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = {"nonce": nonce(), "username": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # @@ -261,28 +260,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = {"nonce": nonce(), "username": "a"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = {"nonce": nonce(), "username": "a", "password": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = {"nonce": nonce(), "username": "a", "password": "A" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # @@ -298,7 +297,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) def test_displayname(self) -> None: @@ -323,11 +322,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob1:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None @@ -347,11 +346,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob2:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty @@ -371,11 +370,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) # set displayname channel = self.make_request("GET", self.url) @@ -394,11 +393,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob4:test/displayname") - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("Bob's Name", channel.json_body["displayname"]) @override_config( @@ -442,7 +441,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -466,7 +465,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -478,7 +477,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, access_token=other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_all_users(self) -> None: @@ -494,7 +493,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, channel.json_body["total"]) @@ -508,7 +507,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): expected_user_id: Optional[str], search_term: str, search_field: Optional[str] = "name", - expected_http_code: Optional[int] = HTTPStatus.OK, + expected_http_code: Optional[int] = 200, ) -> None: """Search for a user and check that the returned user's id is a match @@ -530,7 +529,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if expected_http_code != HTTPStatus.OK: + if expected_http_code != 200: return # Check that users were returned @@ -591,7 +590,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -601,7 +600,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid guests @@ -611,7 +610,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid deactivated @@ -621,7 +620,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # unkown order_by @@ -631,7 +630,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -641,7 +640,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_limit(self) -> None: @@ -659,7 +658,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], "5") @@ -680,7 +679,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -701,7 +700,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(len(channel.json_body["users"]), 10) @@ -724,7 +723,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -737,7 +736,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -750,7 +749,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], "19") @@ -764,7 +763,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -867,7 +866,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["name"] for row in channel.json_body["users"]] @@ -905,6 +904,96 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) +class UserDevicesTestCase(unittest.HomeserverTestCase): + """ + Tests user device management-related Admin APIs. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Set up an Admin user to query the Admin API with. + self.admin_user_id = self.register_user("admin", "pass", admin=True) + self.admin_user_token = self.login("admin", "pass") + + # Set up a test user to query the devices of. + self.other_user_device_id = "TESTDEVICEID" + self.other_user_device_display_name = "My Test Device" + self.other_user_client_ip = "1.2.3.4" + self.other_user_user_agent = "EquestriaTechnology/123.0" + + self.other_user_id = self.register_user("user", "pass", displayname="User1") + self.other_user_token = self.login( + "user", + "pass", + device_id=self.other_user_device_id, + additional_request_fields={ + "initial_device_display_name": self.other_user_device_display_name, + }, + ) + + # Have the "other user" make a request so that the "last_seen_*" fields are + # populated in the tests below. + channel = self.make_request( + "GET", + "/_matrix/client/v3/sync", + access_token=self.other_user_token, + client_ip=self.other_user_client_ip, + custom_headers=[ + ("User-Agent", self.other_user_user_agent), + ], + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + def test_list_user_devices(self) -> None: + """ + Tests that a user's devices and attributes are listed correctly via the Admin API. + """ + # Request all devices of "other user" + channel = self.make_request( + "GET", + f"/_synapse/admin/v2/users/{self.other_user_id}/devices", + access_token=self.admin_user_token, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Double-check we got the single device expected + user_devices = channel.json_body["devices"] + self.assertEqual(len(user_devices), 1) + self.assertEqual(channel.json_body["total"], 1) + + # Check that all the attributes of the device reported are as expected. + self._validate_attributes_of_device_response(user_devices[0]) + + # Request just a single device for "other user" by its ID + channel = self.make_request( + "GET", + f"/_synapse/admin/v2/users/{self.other_user_id}/devices/" + f"{self.other_user_device_id}", + access_token=self.admin_user_token, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Check that all the attributes of the device reported are as expected. + self._validate_attributes_of_device_response(channel.json_body) + + def _validate_attributes_of_device_response(self, response: JsonDict) -> None: + # Check that all device expected attributes are present + self.assertEqual(response["user_id"], self.other_user_id) + self.assertEqual(response["device_id"], self.other_user_device_id) + self.assertEqual(response["display_name"], self.other_user_device_display_name) + self.assertEqual(response["last_seen_ip"], self.other_user_client_ip) + self.assertEqual(response["last_seen_user_agent"], self.other_user_user_agent) + self.assertIsInstance(response["last_seen_ts"], int) + self.assertGreater(response["last_seen_ts"], 0) + + class DeactivateAccountTestCase(unittest.HomeserverTestCase): servlets = [ @@ -941,7 +1030,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self) -> None: @@ -952,7 +1041,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", url, access_token=self.other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -962,12 +1051,12 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self) -> None: """ - Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that deactivation for a user that does not exist returns a 404 """ channel = self.make_request( @@ -976,7 +1065,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_erase_is_not_bool(self) -> None: @@ -991,18 +1080,18 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: """ - Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that deactivation for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain" channel = self.make_request("POST", url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only deactivate local users", channel.json_body["error"]) def test_deactivate_user_erase_true(self) -> None: @@ -1017,7 +1106,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1032,7 +1121,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1041,7 +1130,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1066,7 +1155,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"erase": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self._is_erased("@user:test", True) def test_deactivate_user_erase_false(self) -> None: @@ -1081,7 +1170,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1096,7 +1185,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1105,7 +1194,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1135,7 +1224,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1150,7 +1239,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1159,7 +1248,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1220,7 +1309,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -1230,12 +1319,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ channel = self.make_request( @@ -1244,7 +1333,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) def test_invalid_parameter(self) -> None: @@ -1259,7 +1348,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"admin": "not_bool"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) # deactivated not bool @@ -1269,7 +1358,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": "not_bool"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not str @@ -1279,7 +1368,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"password": True}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not length @@ -1289,7 +1378,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"password": "x" * 513}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # user_type not valid @@ -1299,7 +1388,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"user_type": "new type"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # external_ids not valid @@ -1311,7 +1400,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"} }, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1320,7 +1409,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"external_ids": {"external_id": "id"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # threepids not valid @@ -1330,7 +1419,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": {"medium": "email", "wrong_address": "id"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1339,7 +1428,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": {"address": "value"}}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_get_user(self) -> None: @@ -1352,7 +1441,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) @@ -1395,7 +1484,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1458,7 +1547,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1486,9 +1575,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): # before limit of monthly active users is reached channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok) - if channel.code != HTTPStatus.OK: + if channel.code != 200: raise HttpResponseException( - channel.code, channel.result["reason"], channel.json_body + channel.code, channel.result["reason"], channel.result["body"] ) # Set monthly active users to the limit @@ -1636,6 +1725,41 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(len(pushers), 0) + @override_config( + { + "email": { + "enable_notifs": True, + "notif_for_new_users": True, + "notif_from": "test@example.com", + }, + "public_baseurl": "https://example.com", + } + ) + def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> None: + """ + Check that a new regular user is created successfully when they have a msisdn + threepid and email notif_for_new_users is set to True. + """ + url = self.url_prefix % "@bob:test" + + # Create user + body = { + "password": "abc123", + "threepids": [{"medium": "msisdn", "address": "1234567890"}], + } + + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body, + ) + + self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"]) + def test_set_password(self) -> None: """ Test setting a new password for another user. @@ -1649,7 +1773,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "hahaha"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) def test_set_displayname(self) -> None: @@ -1665,7 +1789,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "foobar"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1676,7 +1800,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1698,7 +1822,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1724,7 +1848,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1740,7 +1864,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1756,7 +1880,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": []}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1783,7 +1907,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1802,7 +1926,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1824,7 +1948,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # other user has this two threepids - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1843,7 +1967,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): url_first_user, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1872,7 +1996,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1904,7 +2028,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1923,7 +2047,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1942,7 +2066,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"external_ids": []}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["external_ids"])) @@ -1971,7 +2095,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1997,7 +2121,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2029,7 +2153,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # must fail - self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body) + self.assertEqual(409, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("External id is already in use.", channel.json_body["error"]) @@ -2040,7 +2164,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2058,7 +2182,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2089,7 +2213,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -2104,7 +2228,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -2123,7 +2247,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -2153,7 +2277,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) @@ -2169,7 +2293,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "Foobar"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) @@ -2193,7 +2317,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) # Reactivate the user. channel = self.make_request( @@ -2202,7 +2326,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self._is_erased("@user:test", False) @@ -2226,7 +2350,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -2236,7 +2360,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self._is_erased("@user:test", False) @@ -2260,7 +2384,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -2270,7 +2394,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self._is_erased("@user:test", False) @@ -2291,7 +2415,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"admin": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2302,7 +2426,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2319,7 +2443,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"user_type": UserTypes.SUPPORT}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2330,7 +2454,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2342,7 +2466,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"user_type": None}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2353,7 +2477,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2383,7 +2507,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual(0, channel.json_body["deactivated"]) @@ -2396,7 +2520,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "deactivated": "false"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) # Check user is not deactivated channel = self.make_request( @@ -2405,7 +2529,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -2430,7 +2554,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": True}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["deactivated"]) self._is_erased(user_id, False) d = self.store.mark_user_erased(user_id) @@ -2485,7 +2609,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -2500,7 +2624,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self) -> None: @@ -2514,7 +2638,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2530,7 +2654,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2546,7 +2670,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2567,7 +2691,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) @@ -2614,7 +2738,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) @@ -2643,7 +2767,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -2658,12 +2782,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" channel = self.make_request( @@ -2672,12 +2796,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" @@ -2687,7 +2811,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_get_pushers(self) -> None: @@ -2702,7 +2826,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) # Register the pusher @@ -2734,7 +2858,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) for p in channel.json_body["pushers"]: @@ -2773,7 +2897,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """Try to list media of an user without authentication.""" channel = self.make_request(method, self.url, {}) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) @@ -2787,12 +2911,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) def test_user_does_not_exist(self, method: str) -> None: - """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" + """Tests that a lookup for a user that does not exist returns a 404""" url = "/_synapse/admin/v1/users/@unknown_person:test/media" channel = self.make_request( method, @@ -2800,12 +2924,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) def test_user_is_not_local(self, method: str) -> None: - """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST""" + """Tests that a lookup for a user that is not a local returns a 400""" url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" channel = self.make_request( @@ -2814,7 +2938,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_limit_GET(self) -> None: @@ -2830,7 +2954,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -2849,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["deleted_media"]), 5) @@ -2866,7 +2990,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -2885,7 +3009,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 15) self.assertEqual(len(channel.json_body["deleted_media"]), 15) @@ -2902,7 +3026,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["media"]), 10) @@ -2921,7 +3045,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["deleted_media"]), 10) @@ -2935,7 +3059,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -2945,7 +3069,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit @@ -2955,7 +3079,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -2965,7 +3089,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self) -> None: @@ -2988,7 +3112,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -3001,7 +3125,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -3014,7 +3138,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -3028,7 +3152,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -3045,7 +3169,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["media"])) @@ -3060,7 +3184,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["deleted_media"])) @@ -3077,7 +3201,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["media"])) self.assertNotIn("next_token", channel.json_body) @@ -3103,7 +3227,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["deleted_media"])) self.assertCountEqual(channel.json_body["deleted_media"], media_ids) @@ -3248,7 +3372,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK + upload_resource, image_data, user_token, filename, expect_code=200 ) # Extract media ID from the response @@ -3266,10 +3390,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - HTTPStatus.OK, + 200, channel.code, msg=( - f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}" + f"Expected to receive a 200 on accessing media: {server_and_media_id}" ), ) @@ -3315,7 +3439,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_media_list)) returned_order = [row["media_id"] for row in channel.json_body["media"]] @@ -3351,14 +3475,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, b"{}", access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) return channel.json_body["access_token"] def test_no_auth(self) -> None: """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_not_admin(self) -> None: @@ -3367,7 +3491,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): "POST", self.url, b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) def test_send_event(self) -> None: """Test that sending event as a user works.""" @@ -3392,7 +3516,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # We should only see the one device (from the login in `prepare`) self.assertEqual(len(channel.json_body["devices"]), 1) @@ -3404,21 +3528,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout with the puppet token channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_user_logout_all(self) -> None: """Tests that the target user calling `/logout/all` does *not* expire @@ -3429,23 +3553,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the real user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should still work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # .. but the real user's tokens shouldn't channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) def test_admin_logout_all(self) -> None: """Tests that the admin user calling `/logout/all` does expire the @@ -3456,23 +3580,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the admin user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) @unittest.override_config( { @@ -3503,7 +3627,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): room_id, "com.example.test", tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + expect_code=403, ) # Login in as the user @@ -3524,7 +3648,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): room_id, user=self.other_user, tok=self.other_user_tok, - expect_code=HTTPStatus.FORBIDDEN, + expect_code=403, ) # Logging in as the other user and joining a room should work, even @@ -3559,7 +3683,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self) -> None: @@ -3574,12 +3698,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=other_user2_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = self.url_prefix % "@unknown_person:unknown_domain" # type: ignore[attr-defined] @@ -3588,7 +3712,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) def test_get_whois_admin(self) -> None: @@ -3600,7 +3724,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -3615,7 +3739,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=other_user_token, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -3645,7 +3769,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. """ channel = self.make_request(method, self.url) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) @@ -3656,18 +3780,18 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): other_user_token = self.login("user", "pass") channel = self.make_request(method, self.url, access_token=other_user_token) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) def test_user_is_not_local(self, method: str) -> None: """ - Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that shadow-banning for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" channel = self.make_request(method, url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) def test_success(self) -> None: """ @@ -3680,7 +3804,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): self.assertFalse(result.shadow_banned) channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual({}, channel.json_body) # Ensure the user is shadow-banned (and the cache was cleared). @@ -3692,7 +3816,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, access_token=self.admin_user_tok ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual({}, channel.json_body) # Ensure the user is no longer shadow-banned (and the cache was cleared). @@ -3727,7 +3851,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): """ channel = self.make_request(method, self.url, b"{}") - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) @@ -3743,13 +3867,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) def test_user_does_not_exist(self, method: str) -> None: """ - Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND + Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" @@ -3759,7 +3883,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand( @@ -3771,7 +3895,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) def test_user_is_not_local(self, method: str, error_msg: str) -> None: """ - Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST + Tests that a lookup for a user that is not a local returns a 400 """ url = ( "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit" @@ -3783,7 +3907,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(error_msg, channel.json_body["error"]) def test_invalid_parameter(self) -> None: @@ -3798,7 +3922,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": "string"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # messages_per_second is negative @@ -3809,7 +3933,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": -1}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is a string @@ -3820,7 +3944,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": "string"}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is negative @@ -3831,7 +3955,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": -1}, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_return_zero_when_null(self) -> None: @@ -3856,7 +3980,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["burst_count"]) @@ -3870,7 +3994,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3881,7 +4005,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 10, "burst_count": 11}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(11, channel.json_body["burst_count"]) @@ -3892,7 +4016,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 20, "burst_count": 21}, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3902,7 +4026,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3912,7 +4036,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3922,7 +4046,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3947,7 +4071,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): """Try to get information of a user without authentication.""" channel = self.make_request("GET", self.url, {}) - self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self) -> None: @@ -3960,7 +4084,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self) -> None: @@ -3973,7 +4097,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self) -> None: @@ -3986,7 +4110,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_success(self) -> None: @@ -4007,7 +4131,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual( {"a": 1}, channel.json_body["account_data"]["global"]["m.global"] ) diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index b21f6d4689..30f12f1bff 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py
@@ -11,9 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from http import HTTPStatus - from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -40,7 +37,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): if username == "allowed": return True raise SynapseError( - HTTPStatus.BAD_REQUEST, + 400, "User ID already taken.", errcode=Codes.USER_IN_USE, ) @@ -50,27 +47,23 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): def test_username_available(self) -> None: """ - The endpoint should return a HTTPStatus.OK response if the username does not exist + The endpoint should return a 200 response if the username does not exist """ url = "%s?username=%s" % (self.url, "allowed") channel = self.make_request("GET", url, access_token=self.admin_user_tok) - self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["available"]) def test_username_unavailable(self) -> None: """ - The endpoint should return a HTTPStatus.OK response if the username does not exist + The endpoint should return a 200 response if the username does not exist """ url = "%s?username=%s" % (self.url, "disallowed") channel = self.make_request("GET", url, access_token=self.admin_user_tok) - self.assertEqual( - HTTPStatus.BAD_REQUEST, - channel.code, - msg=channel.json_body, - ) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE") self.assertEqual(channel.json_body["error"], "User ID already taken.") diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index a43a137273..c1a7fb2f8a 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py
@@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os import re from email.parser import Parser +from http import HTTPStatus from typing import Any, Dict, List, Optional, Union from unittest.mock import Mock @@ -95,10 +95,8 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): """ body = {"type": "m.login.password", "user": username, "password": password} - channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") - ) - self.assertEqual(channel.code, 403, channel.result) + channel = self.make_request("POST", "/_matrix/client/r0/login", body) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) def test_basic_password_reset(self) -> None: """Test basic password reset flow""" @@ -347,7 +345,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button @@ -362,7 +360,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, content_is_form=True, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" @@ -390,7 +388,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): new_password: str, session_id: str, client_secret: str, - expected_code: int = 200, + expected_code: int = HTTPStatus.OK, ) -> None: channel = self.make_request( "POST", @@ -479,20 +477,18 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertEqual(memberships[0].room_id, room_id, memberships) def deactivate(self, user_id: str, tok: str) -> None: - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "test", - }, - "erase": False, - } - ) + request_data = { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + } channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, 200, channel.json_body) class WhoamiTestCase(unittest.HomeserverTestCase): @@ -645,21 +641,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_no_at(self) -> None: self._request_token_invalid_email( "address-without-at.bar", - expected_errcode=Codes.UNKNOWN, + expected_errcode=Codes.BAD_JSON, expected_error="Unable to parse email address", ) def test_add_email_two_at(self) -> None: self._request_token_invalid_email( "foo@foo@test.bar", - expected_errcode=Codes.UNKNOWN, + expected_errcode=Codes.BAD_JSON, expected_error="Unable to parse email address", ) def test_add_email_bad_format(self) -> None: self._request_token_invalid_email( "user@bad.example.net@good.example.com", - expected_errcode=Codes.UNKNOWN, + expected_errcode=Codes.BAD_JSON, expected_error="Unable to parse email address", ) @@ -715,7 +711,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -725,7 +723,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email(self) -> None: @@ -747,7 +745,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -756,7 +754,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email_if_disabled(self) -> None: @@ -781,7 +779,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -791,7 +791,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) @@ -817,7 +817,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -827,7 +829,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_no_valid_token(self) -> None: @@ -852,7 +854,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -862,7 +866,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @override_config({"next_link_domain_whitelist": None}) @@ -872,7 +876,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="https://example.com/a/good/site", - expect_code=200, + expect_code=HTTPStatus.OK, ) @override_config({"next_link_domain_whitelist": None}) @@ -884,7 +888,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="some-protocol://abcdefghijklmopqrstuvwxyz", - expect_code=200, + expect_code=HTTPStatus.OK, ) @override_config({"next_link_domain_whitelist": None}) @@ -895,7 +899,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="file:///host/path", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) @@ -907,28 +911,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link=None, - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://example.com/some/good/page", - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://example.org/some/also/good/page", - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://bad.example.org/some/bad/page", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) @override_config({"next_link_domain_whitelist": []}) @@ -940,7 +944,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="https://example.com/a/page", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) def _request_token( @@ -948,8 +952,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): email: str, client_secret: str, next_link: Optional[str] = None, - expect_code: int = 200, - ) -> str: + expect_code: int = HTTPStatus.OK, + ) -> Optional[str]: """Request a validation token to add an email address to a user's account Args: @@ -959,7 +963,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): expect_code: Expected return code of the call Returns: - The ID of the new threepid validation session + The ID of the new threepid validation session, or None if the response + did not contain a session ID. """ body = {"client_secret": client_secret, "email": email, "send_attempt": 1} if next_link: @@ -992,16 +997,18 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(expected_errcode, channel.json_body["errcode"]) - self.assertEqual(expected_error, channel.json_body["error"]) + self.assertIn(expected_error, channel.json_body["error"]) def _validate_token(self, link: str) -> None: # Remove the host path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" @@ -1051,7 +1058,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -1060,7 +1067,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} @@ -1091,7 +1098,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): """Tests that not providing any MXID raises an error.""" self._test_status( users=None, - expected_status_code=400, + expected_status_code=HTTPStatus.BAD_REQUEST, expected_errcode=Codes.MISSING_PARAM, ) @@ -1099,7 +1106,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): """Tests that providing an invalid MXID raises an error.""" self._test_status( users=["bad:test"], - expected_status_code=400, + expected_status_code=HTTPStatus.BAD_REQUEST, expected_errcode=Codes.INVALID_PARAM, ) @@ -1285,7 +1292,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): def _test_status( self, users: Optional[List[str]], - expected_status_code: int = 200, + expected_status_code: int = HTTPStatus.OK, expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, expected_failures: Optional[List[str]] = None, expected_errcode: Optional[str] = None, diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index aca03afd0e..7a88aa2cda 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py
@@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor +from synapse.appservice import ApplicationService from synapse.rest import admin from synapse.rest.client import directory, login, room from synapse.server import HomeServer @@ -96,8 +96,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # We use deliberately a localpart under the length threshold so # that we can make sure that the check is done on the whole alias. - data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} - request_data = json.dumps(data) + request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) @@ -109,8 +108,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # Check with an alias of allowed length. There should already be # a test that ensures it works in test_register.py, but let's be # as cautious as possible here. - data = {"room_alias_name": random_string(5)} - request_data = json.dumps(data) + request_data = {"room_alias_name": random_string(5)} channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) @@ -129,6 +127,38 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + def test_deleting_alias_via_directory_appservice(self) -> None: + user_id = "@as:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + id="1234", + namespaces={"aliases": [{"regex": "#asns-*", "exclusive": True}]}, + sender=user_id, + ) + self.hs.get_datastores().main.services_cache.append(appservice) + + # Add an alias for the room, as the appservice + alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string() + request_data = {"room_id": self.room_id} + + channel = self.make_request( + "PUT", + f"/_matrix/client/r0/directory/room/{alias}", + request_data, + access_token=as_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + # Then try to remove the alias, as the appservice + channel = self.make_request( + "DELETE", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=as_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + def test_deleting_nonexistant_alias(self) -> None: # Check that no alias exists alias = "#potato:test" @@ -159,8 +189,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.hs.hostname, ) - data = {"aliases": [self.random_alias(alias_length)]} - request_data = json.dumps(data) + request_data = {"aliases": [self.random_alias(alias_length)]} channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok @@ -172,8 +201,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) -> str: alias = self.random_alias(alias_length) url = "/_matrix/client/r0/directory/room/%s" % alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok @@ -181,6 +209,19 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, expected_code, channel.result) return alias + def test_invalid_alias(self) -> None: + alias = "#potato" + channel = self.make_request( + "GET", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertIn("error", channel.json_body, channel.json_body) + self.assertEqual( + channel.json_body["errcode"], "M_INVALID_PARAM", channel.json_body + ) + def random_alias(self, length: int) -> str: return RoomAlias(random_string(length), self.hs.hostname).to_string() diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 823e8ab8c4..afc8d641be 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py
@@ -43,7 +43,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.EXAMPLE_FILTER_JSON, ) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"filter_id": "0"}) filter = self.get_success( self.store.get_user_filter(user_localpart="apple", filter_id=0) @@ -58,7 +58,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.EXAMPLE_FILTER_JSON, ) - self.assertEqual(channel.result["code"], b"403") + self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_add_filter_non_local_user(self) -> None: @@ -71,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.hs.is_mine = _is_mine - self.assertEqual(channel.result["code"], b"403") + self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_get_filter(self) -> None: @@ -85,7 +85,7 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) ) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) def test_get_filter_non_existant(self) -> None: @@ -93,7 +93,7 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) ) - self.assertEqual(channel.result["code"], b"404") + self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode @@ -103,7 +103,7 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) ) - self.assertEqual(channel.result["code"], b"400") + self.assertEqual(channel.code, 400) # No ID also returns an invalid_id error def test_get_filter_no_id(self) -> None: @@ -111,4 +111,4 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) ) - self.assertEqual(channel.result["code"], b"400") + self.assertEqual(channel.code, 400) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 299b9d21e2..b0c8215744 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py
@@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor @@ -26,7 +25,6 @@ from tests import unittest class IdentityTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -34,7 +32,6 @@ class IdentityTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_3pid_lookup"] = False self.hs = self.setup_test_homeserver(config=config) @@ -51,12 +48,12 @@ class IdentityTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK, channel.result) room_id = channel.json_body["room_id"] - params = { + request_data = { "id_server": "testis", "medium": "email", "address": "test@example.com", + "id_access_token": tok, } - request_data = json.dumps(params) request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") channel = self.make_request( b"POST", request_url, request_data, access_token=tok diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f4ea1209d9..e2a4d98275 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py
@@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import time import urllib.parse -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from unittest.mock import Mock from urllib.parse import urlencode @@ -41,7 +40,7 @@ from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless try: - import jwt + from authlib.jose import jwk, jwt HAS_JWT = True except ImportError: @@ -134,10 +133,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -152,7 +151,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config( { @@ -179,10 +178,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -197,7 +196,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config( { @@ -224,10 +223,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -242,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) @override_config({"session_lifetime": "24h"}) def test_soft_logout(self) -> None: @@ -250,7 +249,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we shouldn't be able to make requests without an access token channel = self.make_request(b"GET", TEST_URL) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") # log in as normal @@ -354,7 +353,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # Now try to hard logout this session channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( @@ -380,7 +379,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # Now try to hard log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_login_with_overly_long_device_id_fails(self) -> None: self.register_user("mickey", "cheese") @@ -399,7 +398,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "/_matrix/client/v3/login", - json.dumps(body).encode("utf8"), + body, custom_headers=None, ) @@ -841,7 +840,7 @@ class CASTestCase(unittest.HomeserverTestCase): self.assertIn(b"SSO account deactivated", channel.result["body"]) -@skip_unless(HAS_JWT, "requires jwt") +@skip_unless(HAS_JWT, "requires authlib") class JWTTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -866,11 +865,9 @@ class JWTTestCase(unittest.HomeserverTestCase): return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) - if isinstance(result, bytes): - return result.decode("ascii") - return result + header = {"alg": self.jwt_algorithm} + result: bytes = jwt.encode(header, payload, secret) + return result.decode("ascii") def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} @@ -880,17 +877,17 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_valid_registered(self) -> None: self.register_user("kermit", "monkey") channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") def test_login_jwt_valid_unregistered(self) -> None: channel = self.jwt_login({"sub": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, "notsecret") - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -899,25 +896,26 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_expired(self) -> None: channel = self.jwt_login({"sub": "frog", "exp": 864000}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Signature has expired" + channel.json_body["error"], + "JWT validation failed: expired_token: The token is expired", ) def test_login_jwt_not_before(self) -> None: now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - "JWT validation failed: The token is not yet valid (nbf)", + "JWT validation failed: invalid_token: The token is not valid yet", ) def test_login_no_sub(self) -> None: channel = self.jwt_login({"username": "root"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @@ -926,30 +924,31 @@ class JWTTestCase(unittest.HomeserverTestCase): """Test validating the issuer claim.""" # A valid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid issuer" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "iss"', ) # Not providing an issuer. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - 'JWT validation failed: Token is missing the "iss" claim', + 'JWT validation failed: missing_claim: Missing "iss" claim', ) def test_login_iss_no_config(self) -> None: """Test providing an issuer claim without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) @@ -957,52 +956,54 @@ class JWTTestCase(unittest.HomeserverTestCase): """Test validating the audience claim.""" # A valid audience. channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid audience. channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "aud"', ) # Not providing an audience. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - 'JWT validation failed: Token is missing the "aud" claim', + 'JWT validation failed: missing_claim: Missing "aud" claim', ) def test_login_aud_no_config(self) -> None: """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "aud"', ) def test_login_default_sub(self) -> None: """Test reading user ID from the default subject claim.""" channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) def test_login_custom_sub(self) -> None: """Test reading user ID from a custom subject claim.""" channel = self.jwt_login({"username": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") @@ -1010,7 +1011,7 @@ class JWTTestCase(unittest.HomeserverTestCase): # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use # RSS256, with a public key configured in synapse as "jwt_secret", and tokens # signed by the private key. -@skip_unless(HAS_JWT, "requires jwt") +@skip_unless(HAS_JWT, "requires authlib") class JWTPubKeyTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, @@ -1071,11 +1072,11 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") - if isinstance(result, bytes): - return result.decode("ascii") - return result + header = {"alg": "RS256"} + if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"): + secret = jwk.dumps(secret, kty="RSA") + result: bytes = jwt.encode(header, payload, secret) + return result.decode("ascii") def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} @@ -1084,12 +1085,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def test_login_jwt_valid(self) -> None: channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -1150,7 +1151,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" @@ -1164,7 +1165,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" @@ -1178,7 +1179,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" @@ -1192,7 +1193,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.another_service.token ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice @@ -1206,7 +1207,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) @skip_unless(HAS_OIDC, "requires OIDC") diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py new file mode 100644
index 0000000000..a9da00665e --- /dev/null +++ b/tests/rest/client/test_models.py
@@ -0,0 +1,53 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from pydantic import ValidationError + +from synapse.rest.client.models import EmailRequestTokenBody + + +class EmailRequestTokenBodyTestCase(unittest.TestCase): + base_request = { + "client_secret": "hunter2", + "email": "alice@wonderland.com", + "send_attempt": 1, + } + + def test_token_required_if_id_server_provided(self) -> None: + with self.assertRaises(ValidationError): + EmailRequestTokenBody.parse_obj( + { + **self.base_request, + "id_server": "identity.wonderland.com", + } + ) + with self.assertRaises(ValidationError): + EmailRequestTokenBody.parse_obj( + { + **self.base_request, + "id_server": "identity.wonderland.com", + "id_access_token": None, + } + ) + + def test_token_typechecked_when_id_server_provided(self) -> None: + with self.assertRaises(ValidationError): + EmailRequestTokenBody.parse_obj( + { + **self.base_request, + "id_server": "identity.wonderland.com", + "id_access_token": 1337, + } + ) diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 3a74d2e96c..e19d21d6ee 100644 --- a/tests/rest/client/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py
@@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor @@ -89,7 +88,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_too_short(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "shorty"}) + request_data = {"username": "kermit", "password": "shorty"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -100,7 +99,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_digit(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + request_data = {"username": "kermit", "password": "longerpassword"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -111,7 +110,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_symbol(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + request_data = {"username": "kermit", "password": "l0ngerpassword"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -122,7 +121,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_uppercase(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + request_data = {"username": "kermit", "password": "l0ngerpassword!"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -133,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_lowercase(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + request_data = {"username": "kermit", "password": "L0NGERPASSWORD!"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -144,7 +143,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_compliant(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + request_data = {"username": "kermit", "password": "L0ngerpassword!"} channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has @@ -161,16 +160,14 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", compliant_password) tok = self.login("kermit", compliant_password) - request_data = json.dumps( - { - "new_password": not_compliant_password, - "auth": { - "password": compliant_password, - "type": LoginType.PASSWORD, - "user": user_id, - }, - } - ) + request_data = { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } channel = self.make_request( "POST", "/_matrix/client/r0/account/password", diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 77c3ced42e..8de5a342ae 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py
@@ -13,6 +13,8 @@ # limitations under the License. """Tests REST events for /profile paths.""" +import urllib.parse +from http import HTTPStatus from typing import Any, Dict, Optional from twisted.test.proto_helpers import MemoryReactor @@ -49,6 +51,12 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_displayname() self.assertEqual(res, "owner") + def test_get_displayname_rejects_bad_username(self) -> None: + channel = self.make_request( + "GET", f"/profile/{urllib.parse.quote('@alice:')}/displayname" + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + def test_set_displayname(self) -> None: channel = self.make_request( "PUT", @@ -145,18 +153,22 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def _get_displayname(self, name: Optional[str] = None) -> str: + def _get_displayname(self, name: Optional[str] = None) -> Optional[str]: channel = self.make_request( "GET", "/profile/%s/displayname" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) - return channel.json_body["displayname"] + # FIXME: If a user has no displayname set, Synapse returns 200 and omits a + # displayname from the response. This contradicts the spec, see #13137. + return channel.json_body.get("displayname") - def _get_avatar_url(self, name: Optional[str] = None) -> str: + def _get_avatar_url(self, name: Optional[str] = None) -> Optional[str]: channel = self.make_request( "GET", "/profile/%s/avatar_url" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) + # FIXME: If a user has no avatar set, Synapse returns 200 and omits an + # avatar_url from the response. This contradicts the spec, see #13137. return channel.json_body.get("avatar_url") @unittest.override_config({"max_avatar_size": 50}) diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 7401b5e0c0..be4c67d68e 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py
@@ -76,12 +76,12 @@ class RedactionsTestCase(HomeserverTestCase): path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) channel = self.make_request("POST", path, content={}, access_token=access_token) - self.assertEqual(int(channel.result["code"]), expect_code) + self.assertEqual(channel.code, expect_code) return channel.json_body def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]: channel = self.make_request("GET", "sync", access_token=self.mod_access_token) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index afb08b2736..b781875d52 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py
@@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import datetime -import json import os from typing import Any, Dict, List, Tuple @@ -62,15 +61,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) self.hs.get_datastores().main.services_cache.append(appservice) - request_data = json.dumps( - {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) + request_data = { + "username": "as_user_kermit", + "type": APP_SERVICE_REGISTRATION_TYPE, + } channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) @@ -85,49 +85,46 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) self.hs.get_datastores().main.services_cache.append(appservice) - request_data = json.dumps({"username": "as_user_kermit"}) + request_data = {"username": "as_user_kermit"} channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, 400, msg=channel.result) def test_POST_appservice_registration_invalid(self) -> None: self.appservice = None # no application service exists - request_data = json.dumps( - {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) + request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) def test_POST_bad_password(self) -> None: - request_data = json.dumps({"username": "kermit", "password": 666}) + request_data = {"username": "kermit", "password": 666} channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, 400, msg=channel.result) self.assertEqual(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self) -> None: - request_data = json.dumps({"username": 777, "password": "monkey"}) + request_data = {"username": 777, "password": "monkey"} channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, 400, msg=channel.result) self.assertEqual(channel.json_body["error"], "Invalid username") def test_POST_user_valid(self) -> None: user_id = "@kermit:test" device_id = "frogfone" - params = { + request_data = { "username": "kermit", "password": "monkey", "device_id": device_id, "auth": {"type": LoginType.DUMMY}, } - request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) det_data = { @@ -135,17 +132,17 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) def test_POST_disabled_registration(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request_data = {"username": "kermit", "password": "monkey"} self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["error"], "Registration has been disabled") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -156,7 +153,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self) -> None: @@ -164,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["error"], "Guest access is disabled") @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) @@ -174,40 +171,39 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", url, b"{}") if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self) -> None: for i in range(0, 6): - params = { + request_data = { "username": "kermit" + str(i), "password": "monkey", "device_id": "frogfone", "auth": {"type": LoginType.DUMMY}, } - request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config({"registration_requires_token": True}) def test_POST_registration_requires_token(self) -> None: @@ -234,8 +230,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } # Request without auth to get flows and session - channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] # Synapse adds a dummy stage to differentiate flows where otherwise one # flow would be a subset of another flow. @@ -251,9 +247,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -262,14 +257,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "type": LoginType.DUMMY, "session": session, } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, params) det_data = { "user_id": f"@{username}:{self.hs.hostname}", "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) # Check the `completed` counter has been incremented and pending is 0 @@ -290,7 +284,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "password": "monkey", } # Request without auth to get session - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Test with token param missing (invalid) @@ -298,22 +292,22 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "type": LoginType.REGISTRATION_TOKEN, "session": session, } - channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with non-string (invalid) params["auth"]["token"] = 1234 - channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with unknown token (invalid) params["auth"]["token"] = "1234" - channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -337,9 +331,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): params1: JsonDict = {"username": "bert", "password": "monkey"} params2: JsonDict = {"username": "ernie", "password": "monkey"} # Do 2 requests without auth to get two session IDs - channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + channel1 = self.make_request(b"POST", self.url, params1) session1 = channel1.json_body["session"] - channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + channel2 = self.make_request(b"POST", self.url, params2) session2 = channel2.json_body["session"] # Use token with session1 and check `pending` is 1 @@ -348,9 +342,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session1, } - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Repeat request to make sure pending isn't increased again - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) pending = self.get_success( store.db_pool.simple_select_one_onecol( "registration_tokens", @@ -366,14 +360,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session2, } - channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params2) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Check pending=0 and completed=1 res = self.get_success( store.db_pool.simple_select_one( @@ -386,8 +380,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(res["completed"], 1) # Check auth still fails when using token with session2 - channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params2) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -411,7 +405,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) params: JsonDict = {"username": "kermit", "password": "monkey"} # Request without auth to get session - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Check authentication fails with expired token @@ -420,8 +414,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -435,7 +429,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) # Check authentication succeeds - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -460,9 +454,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Do 2 requests without auth to get two session IDs params1: JsonDict = {"username": "bert", "password": "monkey"} params2: JsonDict = {"username": "ernie", "password": "monkey"} - channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + channel1 = self.make_request(b"POST", self.url, params1) session1 = channel1.json_body["session"] - channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + channel2 = self.make_request(b"POST", self.url, params2) session2 = channel2.json_body["session"] # Use token with both sessions @@ -471,18 +465,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session1, } - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) params2["auth"] = { "type": LoginType.REGISTRATION_TOKEN, "token": token, "session": session2, } - self.make_request(b"POST", self.url, json.dumps(params2)) + self.make_request(b"POST", self.url, params2) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Check `result` of registration token stage for session1 is `True` result1 = self.get_success( @@ -550,7 +544,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Do request without auth to get a session ID params: JsonDict = {"username": "kermit", "password": "monkey"} - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Use token @@ -559,7 +553,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - self.make_request(b"POST", self.url, json.dumps(params)) + self.make_request(b"POST", self.url, params) # Delete token self.get_success( @@ -576,7 +570,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_advertised_flows(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] # with the stock config, we only expect the dummy flow @@ -592,14 +586,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "require_at_registration": True, }, "account_threepid_delegates": { - "email": "https://id_server", "msisdn": "https://id_server", }, + "email": {"notif_from": "Synapse <synapse@example.com>"}, } ) def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] self.assertCountEqual( @@ -631,7 +625,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) def test_advertised_flows_no_msisdn_email_required(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] # with the stock config, we expect all four combinations of 3pid @@ -803,13 +797,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -827,15 +821,14 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = {"user_id": user_id} - request_data = json.dumps(params) + request_data = {"user_id": user_id} channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_manual_expire(self) -> None: user_id = self.register_user("kermit", "monkey") @@ -845,19 +838,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = { + request_data = { "user_id": user_id, "expiration_ts": 0, "enable_renewal_emails": False, } - request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -870,25 +862,24 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = { + request_data = { "user_id": user_id, "expiration_ts": 0, "enable_renewal_emails": False, } - request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Try to log the user out channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): @@ -963,7 +954,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -981,7 +972,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Move 1 day forward. Try to renew with the same token again. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1001,14 +992,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_renewal_invalid_token(self) -> None: # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"404", channel.result) + self.assertEqual(channel.code, 404, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1032,7 +1023,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1041,16 +1032,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): (user_id, tok) = self.create_user() - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "monkey", - }, - "erase": False, - } - ) + request_data = { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + } channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) @@ -1107,7 +1096,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1187,7 +1176,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["valid"], True) def test_GET_token_invalid(self) -> None: @@ -1196,7 +1185,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["valid"], False) @override_config( @@ -1212,10 +1201,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): ) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -1223,4 +1212,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 62e4db23ef..651f4f415d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py
@@ -728,6 +728,7 @@ class RelationsTestCase(BaseRelationsTestCase): class RelationPaginationTestCase(BaseRelationsTestCase): + @unittest.override_config({"experimental_features": {"msc3715_enabled": True}}) def test_basic_paginate_relations(self) -> None: """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -799,7 +800,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): ) expected_event_ids.append(channel.json_body["event_id"]) - prev_token = "" + prev_token: Optional[str] = "" found_event_ids: List[str] = [] for _ in range(20): from_token = "" @@ -998,7 +999,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6) + self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) def test_annotation_to_annotation(self) -> None: """Any relation to an annotation should be ignored.""" @@ -1034,7 +1035,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6) + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) def test_thread(self) -> None: """ @@ -1059,6 +1060,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): participated, bundled_aggregations.get("current_user_participated") ) # The latest thread event has some fields that don't matter. + self.assertIn("latest_event", bundled_aggregations) self.assert_dict( { "content": { @@ -1071,28 +1073,28 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): "sender": self.user2_id, "type": "m.room.test", }, - bundled_aggregations.get("latest_event"), + bundled_aggregations["latest_event"], ) return assert_thread # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # # Note that this re-uses some cached values, so the total number of # queries is much smaller. self._test_bundled_aggregations( - RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token + RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token ) # A user with no interactions with the thread: they have not participated. user3_id, user3_token = self._create_user("charlie") self.helper.join(self.room, user=user3_id, tok=user3_token) self._test_bundled_aggregations( - RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token + RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token ) def test_thread_with_bundled_aggregations_for_latest(self) -> None: @@ -1111,6 +1113,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assertEqual(2, bundled_aggregations.get("count")) self.assertTrue(bundled_aggregations.get("current_user_participated")) # The latest thread event has some fields that don't matter. + self.assertIn("latest_event", bundled_aggregations) self.assert_dict( { "content": { @@ -1123,7 +1126,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): "sender": self.user_id, "type": "m.room.test", }, - bundled_aggregations.get("latest_event"), + bundled_aggregations["latest_event"], ) # Check the unsigned field on the latest event. self.assert_dict( @@ -1139,7 +1142,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) def test_nested_thread(self) -> None: """ diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 20a259fc43..7cb1017a4a 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py
@@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json - from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -77,11 +75,6 @@ class ReportEventTestCase(unittest.HomeserverTestCase): def _assert_status(self, response_status: int, data: JsonDict) -> None: channel = self.make_request( - "POST", - self.report_path, - json.dumps(data), - access_token=self.other_user_tok, - ) - self.assertEqual( - response_status, int(channel.result["code"]), msg=channel.result["body"] + "POST", self.report_path, data, access_token=self.other_user_tok ) + self.assertEqual(response_status, channel.code, msg=channel.result["body"]) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index ac9c113354..9c8c1889d3 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict +from synapse.types import JsonDict, create_requester from synapse.util import Clock from synapse.visibility import filter_events_for_client @@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): message_handler = self.hs.get_message_handler() create_event = self.get_success( message_handler.get_room_data( - self.user_id, room_id, EventTypes.Create, state_key="" + create_requester(self.user_id), room_id, EventTypes.Create, state_key="" ) ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index f523d89b8f..c7eb88d33f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -18,10 +18,14 @@ """Tests REST events for /rooms paths.""" import json -from typing import Any, Dict, Iterable, List, Optional +from http import HTTPStatus +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from unittest.mock import Mock, call from urllib import parse as urlparse +from parameterized import param, parameterized +from typing_extensions import Literal + from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -30,7 +34,9 @@ from synapse.api.constants import ( EventContentFields, EventTypes, Membership, + PublicRoomsFilterFields, RelationTypes, + RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus @@ -42,6 +48,7 @@ from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest +from tests.http.server._base import make_request_with_cancellation_test from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -98,7 +105,7 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # set topic for public room channel = self.make_request( @@ -106,7 +113,7 @@ class RoomPermissionsTestCase(RoomBase): ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -128,28 +135,28 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_topic_perms(self) -> None: topic_content = b'{"topic":"My Topic Name"}' @@ -159,28 +166,28 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 channel = self.make_request("GET", topic_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) @@ -188,25 +195,25 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id channel = self.make_request("GET", topic_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 channel = self.make_request( @@ -214,7 +221,7 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def _test_get_membership( self, room: str, members: Iterable = frozenset(), expect_code: int = 200 @@ -303,14 +310,14 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=self.rmcreator_id, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) self.helper.change_membership( room=room, src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) def test_joined_permissions(self) -> None: @@ -336,7 +343,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # set left of other, expect 403 @@ -345,7 +352,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # set left of self, expect 200 @@ -365,7 +372,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=usr, membership=Membership.INVITE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) self.helper.change_membership( @@ -373,7 +380,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=usr, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # It is always valid to LEAVE if you've already left (currently.) @@ -382,7 +389,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember @@ -399,7 +406,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.BAN, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.FORBIDDEN, ) @@ -409,7 +416,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.BAN, - expect_code=200, + expect_code=HTTPStatus.OK, ) # from ban to invite: Must never happen. @@ -418,7 +425,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.INVITE, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -428,7 +435,7 @@ class RoomPermissionsTestCase(RoomBase): src=other, targ=other, membership=Membership.JOIN, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -438,7 +445,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.BAN, - expect_code=200, + expect_code=HTTPStatus.OK, ) # from ban to knock: Must never happen. @@ -447,7 +454,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.KNOCK, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -457,7 +464,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.LEAVE, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.FORBIDDEN, ) @@ -467,10 +474,53 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.LEAVE, - expect_code=200, + expect_code=HTTPStatus.OK, ) +class RoomStateTestCase(RoomBase): + """Tests /rooms/$room_id/state.""" + + user_id = "@sid1:red" + + def test_get_state_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/state` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_state_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/state" % room_id, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertCountEqual( + [state_event["type"] for state_event in channel.json_list], + { + "m.room.create", + "m.room.power_levels", + "m.room.join_rules", + "m.room.member", + "m.room.history_visibility", + }, + ) + + def test_get_state_event_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/state/$event_type` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_state_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id), + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertEqual(channel.json_body, {"membership": "join"}) + + class RoomsMemberListTestCase(RoomBase): """Tests /rooms/$room_id/members/list REST events.""" @@ -481,16 +531,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self) -> None: channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self) -> None: room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_with_at_token(self) -> None: """ @@ -501,7 +551,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that permission is denied for @sid1:red to get the @@ -510,7 +560,7 @@ class RoomsMemberListTestCase(RoomBase): "GET", f"/rooms/{room_id}/members?at={sync_token}", ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member(self) -> None: """ @@ -523,14 +573,14 @@ class RoomsMemberListTestCase(RoomBase): # check that the user can see the member list to start with channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") # check the user can no longer see the member list channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member_with_at_token(self) -> None: """ @@ -544,14 +594,14 @@ class RoomsMemberListTestCase(RoomBase): # sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that the user can see the member list to start with channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user (Note: the user is actually allowed to see this event and # state so that they know they're banned!) @@ -563,14 +613,14 @@ class RoomsMemberListTestCase(RoomBase): # now, with the original user, sync again to get a new at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check the user can no longer see the updated member list channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self) -> None: room_creator = "@some_other_guy:red" @@ -579,17 +629,73 @@ class RoomsMemberListTestCase(RoomBase): self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. channel = self.make_request("GET", room_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + + def test_get_member_list_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/members` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_get_member_list_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/members" % room_id, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertLessEqual( + { + "content": {"membership": "join"}, + "room_id": room_id, + "sender": self.user_id, + "state_key": self.user_id, + "type": "m.room.member", + "user_id": self.user_id, + }.items(), + channel.json_body["chunk"][0].items(), + ) + + def test_get_member_list_with_at_token_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/members?at=<sync token>` request.""" + room_id = self.helper.create_room_as(self.user_id) + + # first sync to get an at token + channel = self.make_request("GET", "/sync") + self.assertEqual(HTTPStatus.OK, channel.code) + sync_token = channel.json_body["next_batch"] + + channel = make_request_with_cancellation_test( + "test_get_member_list_with_at_token_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/members?at=%s" % (room_id, sync_token), + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertLessEqual( + { + "content": {"membership": "join"}, + "room_id": room_id, + "sender": self.user_id, + "state_key": self.user_id, + "type": "m.room.member", + "user_id": self.user_id, + }.items(), + channel.json_body["chunk"][0].items(), + ) class RoomsCreateTestCase(RoomBase): @@ -601,19 +707,34 @@ class RoomsCreateTestCase(RoomBase): # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(44, channel.resource_usage.db_txn_count) + + def test_post_room_initial_state(self) -> None: + # POST with initial_state config key, expect new room id + channel = self.make_request( + "POST", + "/createRoom", + b'{"initial_state":[{"type": "m.bridge", "content": {}}]}', + ) + + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + self.assertTrue("room_id" in channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(50, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self) -> None: # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self) -> None: @@ -621,16 +742,16 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self) -> None: # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) def test_post_room_invitees_invalid_mxid(self) -> None: # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 @@ -638,7 +759,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) def test_post_room_invitees_ratelimit(self) -> None: @@ -649,20 +770,18 @@ class RoomsCreateTestCase(RoomBase): # Build the request's content. We use local MXIDs because invites over federation # are more difficult to mock. - content = json.dumps( - { - "invite": [ - "@alice1:red", - "@alice2:red", - "@alice3:red", - "@alice4:red", - ] - } - ).encode("utf8") + content = { + "invite": [ + "@alice1:red", + "@alice2:red", + "@alice3:red", + "@alice4:red", + ] + } # Test that the invites are correctly ratelimited. channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) self.assertEqual( "Cannot invite so many users at once", channel.json_body["error"], @@ -675,11 +794,13 @@ class RoomsCreateTestCase(RoomBase): # Test that the invites aren't ratelimited anymore. channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) - def test_spam_checker_may_join_room(self) -> None: + def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly bypassed when creating a new room. + + In this test, we use the deprecated API in which callbacks return a bool. """ async def user_may_join_room( @@ -697,10 +818,55 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) + def test_spam_checker_may_join_room(self) -> None: + """Tests that the user_may_join_room spam checker callback is correctly bypassed + when creating a new room. + + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. + """ + + async def user_may_join_room_codes( + mxid: str, + room_id: str, + is_invite: bool, + ) -> Codes: + return Codes.CONSENT_NOT_GIVEN + + join_mock = Mock(side_effect=user_may_join_room_codes) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock) + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + self.assertEqual(join_mock.call_count, 0) + + # Now change the return value of the callback to deny any join. Since we're + # creating the room, despite the return value, we should be able to join. + async def user_may_join_room_tuple( + mxid: str, + room_id: str, + is_invite: bool, + ) -> Tuple[Codes, dict]: + return Codes.INCOMPATIBLE_ROOM_VERSION, {} + + join_mock.side_effect = user_may_join_room_tuple + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + self.assertEqual(join_mock.call_count, 0) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -715,54 +881,68 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self) -> None: # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid key, wrong type content = '{"topic":["Topic name"]}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_topic(self) -> None: # nothing should be there channel = self.make_request("GET", self.path) - self.assertEqual(404, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self) -> None: # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -778,22 +958,34 @@ class RoomMemberStateTestCase(RoomBase): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -802,7 +994,9 @@ class RoomMemberStateTestCase(RoomBase): Membership.LEAVE, ) channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_members_self(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % ( @@ -813,10 +1007,10 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEqual(expected_response, channel.json_body) @@ -831,10 +1025,10 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self) -> None: @@ -850,10 +1044,10 @@ class RoomMemberStateTestCase(RoomBase): "Join us!", ) channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) @@ -911,9 +1105,11 @@ class RoomJoinTestCase(RoomBase): self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) - def test_spam_checker_may_join_room(self) -> None: + def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called and blocks room joins when needed. + + This test uses the deprecated API, in which callbacks return booleans. """ # Register a dummy callback. Make it allow all room joins for now. @@ -926,6 +1122,8 @@ class RoomJoinTestCase(RoomBase): ) -> bool: return return_value + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None) self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) @@ -966,7 +1164,92 @@ class RoomJoinTestCase(RoomBase): # Now make the callback deny all room joins, and check that a join actually fails. return_value = False - self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + self.helper.join( + self.room3, self.user2, expect_code=HTTPStatus.FORBIDDEN, tok=self.tok2 + ) + + def test_spam_checker_may_join_room(self) -> None: + """Tests that the user_may_join_room spam checker callback is correctly called + and blocks room joins when needed. + + This test uses the latest API to this day, in which callbacks return `NOT_SPAM` or `Codes`. + """ + + # Register a dummy callback. Make it allow all room joins for now. + return_value: Union[ + Literal["NOT_SPAM"], Tuple[Codes, dict], Codes + ] = synapse.module_api.NOT_SPAM + + async def user_may_join_room( + userid: str, + room_id: str, + is_invited: bool, + ) -> Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes]: + return return_value + + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. + callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) + + # Join a first room, without being invited to it. + self.helper.join(self.room1, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room1, + False, + ), + ) + self.assertEqual( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Join a second room, this time with an invite for it. + self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1) + self.helper.join(self.room2, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room2, + True, + ), + ) + self.assertEqual( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + # We pick an arbitrary Codes rather than the default `Codes.FORBIDDEN`. + return_value = Codes.CONSENT_NOT_GIVEN + self.helper.invite(self.room3, self.user1, self.user2, tok=self.tok1) + self.helper.join( + self.room3, + self.user2, + expect_code=HTTPStatus.FORBIDDEN, + expect_errcode=return_value, + tok=self.tok2, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + # As above, with the experimental extension that lets us return dictionaries. + return_value = (Codes.BAD_ALIAS, {"another_field": "12345"}) + self.helper.join( + self.room3, + self.user2, + expect_code=HTTPStatus.FORBIDDEN, + expect_errcode=return_value[0], + tok=self.tok2, + expect_additional_fields=return_value[1], + ) class RoomJoinRatelimitTestCase(RoomBase): @@ -1016,7 +1299,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that all the rooms have been sent a profile update into. for room_id in room_ids: @@ -1081,40 +1364,153 @@ class RoomMessagesTestCase(RoomBase): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_messages_sent(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' channel = self.make_request("PUT", path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + + @parameterized.expand( + [ + # Allow + param( + name="NOT_SPAM", + value="NOT_SPAM", + expected_code=HTTPStatus.OK, + expected_fields={}, + ), + param( + name="False", + value=False, + expected_code=HTTPStatus.OK, + expected_fields={}, + ), + # Block + param( + name="scalene string", + value="ANY OTHER STRING", + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_FORBIDDEN"}, + ), + param( + name="True", + value=True, + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_FORBIDDEN"}, + ), + param( + name="Code", + value=Codes.LIMIT_EXCEEDED, + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_LIMIT_EXCEEDED"}, + ), + param( + name="Tuple", + value=(Codes.SERVER_NOT_TRUSTED, {"additional_field": "12345"}), + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={ + "errcode": "M_SERVER_NOT_TRUSTED", + "additional_field": "12345", + }, + ), + ] + ) + def test_spam_checker_check_event_for_spam( + self, + name: str, + value: Union[str, bool, Codes, Tuple[Codes, JsonDict]], + expected_code: int, + expected_fields: dict, + ) -> None: + class SpamCheck: + mock_return_value: Union[ + str, bool, Codes, Tuple[Codes, JsonDict], bool + ] = "NOT_SPAM" + mock_content: Optional[JsonDict] = None + + async def check_event_for_spam( + self, + event: synapse.events.EventBase, + ) -> Union[str, Codes, Tuple[Codes, JsonDict], bool]: + self.mock_content = event.content + return self.mock_return_value + + spam_checker = SpamCheck() + + self.hs.get_spam_checker()._check_event_for_spam_callbacks.append( + spam_checker.check_event_for_spam + ) + + # Inject `value` as mock_return_value + spam_checker.mock_return_value = value + path = "/rooms/%s/send/m.room.message/check_event_for_spam_%s" % ( + urlparse.quote(self.room_id), + urlparse.quote(name), + ) + body = "test-%s" % name + content = '{"body":"%s","msgtype":"m.text"}' % body + channel = self.make_request("PUT", path, content) + + # Check that the callback has witnessed the correct event. + self.assertIsNotNone(spam_checker.mock_content) + if ( + spam_checker.mock_content is not None + ): # Checked just above, but mypy doesn't know about that. + self.assertEqual( + spam_checker.mock_content["body"], body, spam_checker.mock_content + ) + + # Check that we have the correct result. + self.assertEqual(expected_code, channel.code, msg=channel.result["body"]) + for expected_key, expected_value in expected_fields.items(): + self.assertEqual( + channel.json_body.get(expected_key, None), + expected_value, + "Field %s absent or invalid " % expected_key, + ) class RoomPowerLevelOverridesTestCase(RoomBase): @@ -1239,7 +1635,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_normal_user_can_not_post_state_event(self) -> None: # Given I am a normal member of a room @@ -1253,7 +1649,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed because state events require PL>=50 - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " "user_level (0) < send_level (50)", @@ -1280,7 +1676,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1308,7 +1704,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1336,7 +1732,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " + "user_level (0) < send_level (1)", @@ -1367,7 +1763,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): # Then I am not allowed because the public_chat config does not # affect this room, because this room is a private_chat - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " + "user_level (0) < send_level (50)", @@ -1386,7 +1782,7 @@ class RoomInitialSyncTestCase(RoomBase): def test_initial_sync(self) -> None: channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertEqual(self.room_id, channel.json_body["room_id"]) self.assertEqual("join", channel.json_body["membership"]) @@ -1429,7 +1825,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1440,7 +1836,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1479,7 +1875,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) @@ -1507,7 +1903,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) @@ -1524,7 +1920,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) @@ -1652,14 +2048,97 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): def test_restricted_no_auth(self) -> None: channel = self.make_request("GET", self.url) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) def test_restricted_auth(self) -> None: self.register_user("user", "pass") tok = self.login("user", "pass") channel = self.make_request("GET", self.url, access_token=tok) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + +class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + + config = self.default_config() + config["allow_public_rooms_without_auth"] = True + self.hs = self.setup_test_homeserver(config=config) + self.url = b"/_matrix/client/r0/publicRooms" + + return self.hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + user = self.register_user("alice", "pass") + self.token = self.login(user, "pass") + + # Create a room + self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=self.token, + ) + # Create a space + self.helper.create_room_as( + user, + is_public=True, + extra_content={ + "visibility": "public", + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}, + }, + tok=self.token, + ) + + def make_public_rooms_request( + self, room_types: Union[List[Union[str, None]], None] + ) -> Tuple[List[Dict[str, Any]], int]: + channel = self.make_request( + "POST", + self.url, + {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}, + self.token, + ) + chunk = channel.json_body["chunk"] + count = channel.json_body["total_room_count_estimate"] + + self.assertEqual(len(chunk), count) + + return chunk, count + + def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None: + chunk, count = self.make_public_rooms_request(None) + + self.assertEqual(count, 2) + + def test_returns_only_rooms_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request([None]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("room_type", None), None) + + def test_returns_only_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space"]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("room_type", None), "m.space") + + def test_returns_both_rooms_and_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space", None]) + + self.assertEqual(count, 2) + + def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None: + chunk, count = self.make_public_rooms_request([]) + + self.assertEqual(count, 2) class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): @@ -1686,7 +2165,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): "Simple test for searching rooms over federation" self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", @@ -1694,7 +2173,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined] "testserv", @@ -1711,11 +2190,11 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): # The `get_public_rooms` should be called again if the first call fails # with a 404, when using search filters. self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] - HttpResponseException(404, "Not Found", b""), + HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""), make_awaitable({}), ) - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", @@ -1723,7 +2202,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined] [ @@ -1769,21 +2248,19 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): # Set a profile for the test user self.displayname = "test user" - data = {"displayname": self.displayname} - request_data = json.dumps(data) + request_data = {"displayname": self.displayname} channel = self.make_request( "PUT", "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) def test_per_room_profile_forbidden(self) -> None: - data = {"membership": "join", "displayname": "other test user"} - request_data = json.dumps(data) + request_data = {"membership": "join", "displayname": "other test user"} channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" @@ -1791,7 +2268,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1799,7 +2276,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) @@ -1833,7 +2310,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1847,7 +2324,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1861,7 +2338,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1875,7 +2352,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1887,7 +2364,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1899,7 +2376,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1918,7 +2395,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1930,7 +2407,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): ), access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_content = channel.json_body @@ -1978,7 +2455,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2008,7 +2485,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2043,7 +2520,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2123,16 +2600,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_labels(self) -> None: """Test that we can filter by a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2160,16 +2635,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2209,16 +2682,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /search request. """ - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2391,7 +2862,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) return channel.json_body["chunk"] @@ -2496,7 +2967,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2562,7 +3033,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=invited_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2663,8 +3134,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -2693,8 +3163,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -2720,7 +3189,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), - json.dumps(content), + content, access_token=self.room_owner_tok, ) self.assertEqual(channel.code, expected_code, channel.result) @@ -2845,11 +3314,16 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_threepid_invite_spamcheck(self) -> None: + def test_threepid_invite_spamcheck_deprecated(self) -> None: + """ + Test allowing/blocking threepid invites with a spam-check module. + + In this test, we use the deprecated API in which callbacks return a bool. + """ # Mock a few functions to prevent the test from failing due to failing to talk to - # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we + # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable(0)) + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock self.hs.get_identity_handler().lookup_3pid = Mock( return_value=make_awaitable(None), @@ -2901,3 +3375,107 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Also check that it stopped before calling _make_and_store_3pid_invite. make_invite_mock.assert_called_once() + + def test_threepid_invite_spamcheck(self) -> None: + """ + Test allowing/blocking threepid invites with a spam-check module. + + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.""" + # Mock a few functions to prevent the test from failing due to failing to talk to + # a remote IS. We keep the mock for make_and_store_3pid_invite around so we + # can check its call_count later on during the test. + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock + self.hs.get_identity_handler().lookup_3pid = Mock( + return_value=make_awaitable(None), + ) + + # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it + # allow everything for now. + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. + mock = Mock( + return_value=make_awaitable(synapse.module_api.NOT_SPAM), + spec=lambda *x: None, + ) + self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) + + # Send a 3PID invite into the room and check that it succeeded. + email_to_invite = "teresa@example.com" + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + + # Check that the callback was called with the right params. + mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) + + # Check that the call to send the invite was made. + make_invite_mock.assert_called_once() + + # Now change the return value of the callback to deny any invite and test that + # we can't send the invite. We pick an arbitrary error code to be able to check + # that the same code has been returned + mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.CONSENT_NOT_GIVEN) + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once() + + # Run variant with `Tuple[Codes, dict]`. + mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"})) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT) + self.assertEqual(channel.json_body["field"], "value") + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once() + + def test_400_missing_param_without_id_access_token(self) -> None: + """ + Test that a 3pid invite request returns 400 M_MISSING_PARAM + if we do not include id_access_token. + """ + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "medium": "email", + "address": "teresa@example.com", + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM") diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d9bd8c4a28..c807a37bc2 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py
@@ -26,7 +26,7 @@ from synapse.rest.client import ( room_upgrade_rest_servlet, ) from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import UserID, create_requester from synapse.util import Clock from tests import unittest @@ -97,7 +97,12 @@ class RoomTestCase(_ShadowBannedBase): channel = self.make_request( "POST", "/rooms/%s/invite" % (room_id,), - {"id_server": "test", "medium": "email", "address": "test@test.test"}, + { + "id_server": "test", + "medium": "email", + "address": "test@test.test", + "id_access_token": "anytoken", + }, access_token=self.banned_access_token, ) self.assertEqual(200, channel.code, channel.result) @@ -275,7 +280,7 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, + create_requester(self.banned_user_id), room_id, "m.room.member", self.banned_user_id, @@ -310,7 +315,7 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, + create_requester(self.banned_user_id), room_id, "m.room.member", self.banned_user_id, diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index e3efd1f1b0..0af643ecd9 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py
@@ -38,7 +38,6 @@ from tests.federation.transport.test_knocking import ( KnockingStrippedStateEventHelperMixin, ) from tests.server import TimedOutException -from tests.unittest import override_config class FilterTestCase(unittest.HomeserverTestCase): @@ -390,6 +389,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + return self.setup_test_homeserver(config=config) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -408,7 +412,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Join the second user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - @override_config({"experimental_features": {"msc2285_enabled": True}}) def test_private_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -416,7 +419,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt to tell the server the first user's message was read channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -425,7 +428,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that the first user can't see the other user's private read receipt self.assertIsNone(self._get_read_receipt()) - @override_config({"experimental_features": {"msc2285_enabled": True}}) def test_public_receipt_can_override_private(self) -> None: """ Sending a public read receipt to the same event which has a private read @@ -456,7 +458,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that we did override the private read receipt self.assertNotEqual(self._get_read_receipt(), None) - @override_config({"experimental_features": {"msc2285_enabled": True}}) def test_private_receipt_cannot_override_public(self) -> None: """ Sending a private read receipt to the same event which has a public read @@ -543,7 +544,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): config = super().default_config() config["experimental_features"] = { "msc2654_enabled": True, - "msc2285_enabled": True, } return config @@ -606,11 +606,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(1) # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8") channel = self.make_request( "POST", f"/rooms/{self.room_id}/read_markers", - body, + {ReceiptTypes.READ: res["event_id"]}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -625,7 +624,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Send a read receipt to tell the server we've read the latest event. channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok, ) @@ -701,7 +700,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(5) res2 = self.helper.send(self.room_id, "hello", tok=self.tok2) - # Make sure both m.read and org.matrix.msc2285.read.private advance + # Make sure both m.read and m.read.private advance channel = self.make_request( "POST", f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}", @@ -713,16 +712,21 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) self._check_unread_count(0) - # We test for both receipt types that influence notification counts - @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]) - def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None: + # We test for all three receipt types that influence notification counts + @parameterized.expand( + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ] + ) + def test_read_receipts_only_go_down(self, receipt_type: str) -> None: # Join the new user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) @@ -733,18 +737,18 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Read last event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) self._check_unread_count(0) - # Make sure neither m.read nor org.matrix.msc2285.read.private make the + # Make sure neither m.read nor m.read.private make the # read receipt go up to an older event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res1['event_id']}", {}, access_token=self.tok, ) @@ -949,3 +953,24 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase): self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"]) self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"]) + + def test_incremental_sync(self) -> None: + """Tests that activity in the room is properly filtered out of incremental + syncs. + """ + channel = self.make_request("GET", "/sync", access_token=self.tok) + self.assertEqual(channel.code, 200, channel.result) + next_batch = channel.json_body["next_batch"] + + self.helper.send(self.excluded_room_id, tok=self.tok) + self.helper.send(self.included_room_id, tok=self.tok) + + channel = self.make_request( + "GET", + f"/sync?since={next_batch}", + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"]) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 5eb0f243f7..3325d43a2f 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py
@@ -20,8 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError from synapse.api.room_versions import RoomVersion +from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase -from synapse.events.snapshot import EventContext from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin from synapse.rest.client import account, login, profile, room @@ -113,14 +113,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Have this homeserver skip event auth checks. This is necessary due to # event auth checks ensuring that events were signed by the sender's homeserver. - async def _check_event_auth( - origin: str, - event: EventBase, - context: EventContext, - *args: Any, - **kwargs: Any, - ) -> EventContext: - return context + async def _check_event_auth(origin: Any, event: Any, context: Any) -> None: + pass hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] @@ -161,7 +155,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) callback.assert_called_once() @@ -179,7 +173,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, channel.result) def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None: """ @@ -192,12 +186,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): """ class NastyHackException(SynapseError): - def error_dict(self) -> JsonDict: + def error_dict(self, config: Optional[HomeServerConfig]) -> JsonDict: """ This overrides SynapseError's `error_dict` to nastily inject JSON into the error response. """ - result = super().error_dict() + result = super().error_dict(config) result["nasty"] = "very" return result @@ -217,7 +211,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): access_token=self.tok, ) # Check the error code - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, channel.result) # Check the JSON body has had the `nasty` key injected self.assertEqual( channel.json_body, @@ -266,7 +260,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {"x": "x"}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) event_id = channel.json_body["event_id"] # ... and check that it got modified @@ -275,7 +269,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) ev = channel.json_body self.assertEqual(ev["content"]["x"], "y") @@ -304,7 +298,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) orig_event_id = channel.json_body["event_id"] channel = self.make_request( @@ -321,7 +315,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) edited_event_id = channel.json_body["event_id"] # ... and check that they both got modified @@ -330,7 +324,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) ev = channel.json_body self.assertEqual(ev["content"]["body"], "ORIGINAL BODY") @@ -339,7 +333,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) ev = channel.json_body self.assertEqual(ev["content"]["body"], "EDITED BODY") @@ -385,7 +379,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) event_id = channel.json_body["event_id"] @@ -394,7 +388,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) self.assertIn("foo", channel.json_body["content"].keys()) self.assertEqual(channel.json_body["content"]["foo"], "bar") diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 98c1039d33..5e7bf97482 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py
@@ -48,10 +48,14 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.helper.join(self.room_id, self.other, tok=self.other_token) def _upgrade_room( - self, token: Optional[str] = None, room_id: Optional[str] = None + self, + token: Optional[str] = None, + room_id: Optional[str] = None, + expire_cache: bool = True, ) -> FakeChannel: - # We never want a cached response. - self.reactor.advance(5 * 60 + 1) + if expire_cache: + # We don't want a cached response. + self.reactor.advance(5 * 60 + 1) if room_id is None: room_id = self.room_id @@ -72,9 +76,24 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) - def test_not_in_room(self) -> None: + new_room_id = channel.json_body["replacement_room"] + + # Check that the tombstone event points to the new room. + tombstone_event = self.get_success( + self.hs.get_storage_controllers().state.get_current_state_event( + self.room_id, EventTypes.Tombstone, "" + ) + ) + self.assertIsNotNone(tombstone_event) + self.assertEqual(new_room_id, tombstone_event.content["replacement_room"]) + + # Check that the new room exists. + room = self.get_success(self.store.get_room(new_room_id)) + self.assertIsNotNone(room) + + def test_never_in_room(self) -> None: """ - Upgrading a room should work fine. + A user who has never been in the room cannot upgrade the room. """ # The user isn't in the room. roomless = self.register_user("roomless", "pass") @@ -83,6 +102,16 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): channel = self._upgrade_room(roomless_token) self.assertEqual(403, channel.code, channel.result) + def test_left_room(self) -> None: + """ + A user who is no longer in the room cannot upgrade the room. + """ + # Remove the user from the room. + self.helper.leave(self.room_id, self.creator, tok=self.creator_token) + + channel = self._upgrade_room(self.creator_token) + self.assertEqual(403, channel.code, channel.result) + def test_power_levels(self) -> None: """ Another user can upgrade the room if their power level is increased. @@ -297,3 +326,47 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.assertEqual( create_event.content.get(EventContentFields.ROOM_TYPE), test_room_type ) + + def test_second_upgrade_from_same_user(self) -> None: + """A second room upgrade from the same user is deduplicated.""" + channel1 = self._upgrade_room() + self.assertEqual(200, channel1.code, channel1.result) + + channel2 = self._upgrade_room(expire_cache=False) + self.assertEqual(200, channel2.code, channel2.result) + + self.assertEqual( + channel1.json_body["replacement_room"], + channel2.json_body["replacement_room"], + ) + + def test_second_upgrade_after_delay(self) -> None: + """A second room upgrade is not deduplicated after some time has passed.""" + channel1 = self._upgrade_room() + self.assertEqual(200, channel1.code, channel1.result) + + channel2 = self._upgrade_room(expire_cache=True) + self.assertEqual(200, channel2.code, channel2.result) + + self.assertNotEqual( + channel1.json_body["replacement_room"], + channel2.json_body["replacement_room"], + ) + + def test_second_upgrade_from_different_user(self) -> None: + """A second room upgrade from a different user is blocked.""" + channel = self._upgrade_room() + self.assertEqual(200, channel.code, channel.result) + + channel = self._upgrade_room(self.other_token, expire_cache=False) + self.assertEqual(400, channel.code, channel.result) + + def test_first_upgrade_does_not_block_second(self) -> None: + """A second room upgrade is not blocked when a previous upgrade attempt was not + allowed. + """ + channel = self._upgrade_room(self.other_token) + self.assertEqual(403, channel.code, channel.result) + + channel = self._upgrade_room(expire_cache=False) + self.assertEqual(200, channel.code, channel.result) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index a0788b1bb0..dd26145bf8 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py
@@ -41,6 +41,7 @@ from twisted.web.resource import Resource from twisted.web.server import Site from synapse.api.constants import Membership +from synapse.api.errors import Codes from synapse.server import HomeServer from synapse.types import JsonDict @@ -135,11 +136,11 @@ class RestHelper: self.site, "POST", path, - json.dumps(content).encode("utf8"), + content, custom_headers=custom_headers, ) - assert channel.result["code"] == b"%d" % expect_code, channel.result + assert channel.code == expect_code, channel.result self.auth_user_id = temp_id if expect_code == HTTPStatus.OK: @@ -171,6 +172,8 @@ class RestHelper: expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, + expect_errcode: Optional[Codes] = None, + expect_additional_fields: Optional[dict] = None, ) -> None: self.change_membership( room=room, @@ -180,6 +183,8 @@ class RestHelper: appservice_user_id=appservice_user_id, membership=Membership.JOIN, expect_code=expect_code, + expect_errcode=expect_errcode, + expect_additional_fields=expect_additional_fields, ) def knock( @@ -205,14 +210,12 @@ class RestHelper: self.site, "POST", path, - json.dumps(data).encode("utf8"), + data, ) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -263,6 +266,7 @@ class RestHelper: appservice_user_id: Optional[str] = None, expect_code: int = HTTPStatus.OK, expect_errcode: Optional[str] = None, + expect_additional_fields: Optional[dict] = None, ) -> None: """ Send a membership state event into a room. @@ -303,14 +307,12 @@ class RestHelper: self.site, "PUT", path, - json.dumps(data).encode("utf8"), + data, ) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -323,6 +325,21 @@ class RestHelper: channel.result["body"], ) + if expect_additional_fields is not None: + for expect_key, expect_value in expect_additional_fields.items(): + assert expect_key in channel.json_body, "Expected field %s, got %s" % ( + expect_key, + channel.json_body, + ) + assert ( + channel.json_body[expect_key] == expect_value + ), "Expected: %s at %s, got: %s, resp: %s" % ( + expect_value, + expect_key, + channel.json_body[expect_key], + channel.json_body, + ) + self.auth_user_id = temp_id def send( @@ -371,15 +388,13 @@ class RestHelper: self.site, "PUT", path, - json.dumps(content or {}).encode("utf8"), + content or {}, custom_headers=custom_headers, ) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -428,11 +443,9 @@ class RestHelper: channel = make_request(self.hs.get_reactor(), self.site, method, path, content) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -524,7 +537,7 @@ class RestHelper: assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index ea9e5889bf..1062081a06 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py
@@ -370,6 +370,64 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "รณ", "og:description": "Some text."}) + def test_twitter_tag(self) -> None: + """Twitter card tags should be used if nothing else is available.""" + html = b""" + <html> + <meta name="twitter:card" content="summary"> + <meta name="twitter:description" content="Description"> + <meta name="twitter:site" content="@matrixdotorg"> + </html> + """ + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + self.assertEqual( + og, + { + "og:title": None, + "og:description": "Description", + "og:site_name": "@matrixdotorg", + }, + ) + + # But they shouldn't override Open Graph values. + html = b""" + <html> + <meta name="twitter:card" content="summary"> + <meta name="twitter:description" content="Description"> + <meta property="og:description" content="Real Description"> + <meta name="twitter:site" content="@matrixdotorg"> + <meta property="og:site_name" content="matrix.org"> + </html> + """ + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + self.assertEqual( + og, + { + "og:title": None, + "og:description": "Real Description", + "og:site_name": "matrix.org", + }, + ) + + def test_nested_nodes(self) -> None: + """A body with some nested nodes. Tests that we iterate over children + in the right order (and don't reverse the order of the text).""" + html = b""" + <a href="somewhere">Welcome <b>the bold <u>and underlined text <svg> + with a cheeky SVG</svg></u> and <strong>some</strong> tail text</b></a> + """ + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + self.assertEqual( + og, + { + "og:title": None, + "og:description": "Welcome\n\nthe bold\n\nand underlined text\n\nand\n\nsome\n\ntail text", + }, + ) + class MediaEncodingTestCase(unittest.TestCase): def test_meta_charset(self) -> None: diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 7204b2dfe0..d18fc13c21 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py
@@ -23,11 +23,13 @@ from urllib import parse import attr from parameterized import parameterized, parameterized_class from PIL import Image as Image +from typing_extensions import Literal from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor +from synapse.api.errors import Codes from synapse.events import EventBase from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.logging.context import make_deferred_yieldable @@ -124,7 +126,9 @@ class _TestImage: expected_scaled: The expected bytes from scaled thumbnailing, or None if test should just check for a valid image returned. expected_found: True if the file should exist on the server, or False if - a 404 is expected. + a 404/400 is expected. + unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or + False if the thumbnailing should succeed or a normal 404 is expected. """ data: bytes @@ -133,6 +137,7 @@ class _TestImage: expected_cropped: Optional[bytes] = None expected_scaled: Optional[bytes] = None expected_found: bool = True + unable_to_thumbnail: bool = False @parameterized_class( @@ -190,6 +195,7 @@ class _TestImage: b"image/gif", b".gif", expected_found=False, + unable_to_thumbnail=True, ), ), ], @@ -364,18 +370,29 @@ class MediaRepoTests(unittest.HomeserverTestCase): def test_thumbnail_crop(self) -> None: """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( - "crop", self.test_image.expected_cropped, self.test_image.expected_found + "crop", + self.test_image.expected_cropped, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) def test_thumbnail_scale(self) -> None: """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( - "scale", self.test_image.expected_scaled, self.test_image.expected_found + "scale", + self.test_image.expected_scaled, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) def test_invalid_type(self) -> None: """An invalid thumbnail type is never available.""" - self._test_thumbnail("invalid", None, False) + self._test_thumbnail( + "invalid", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} @@ -384,7 +401,12 @@ class MediaRepoTests(unittest.HomeserverTestCase): """ Override the config to generate only scaled thumbnails, but request a cropped one. """ - self._test_thumbnail("crop", None, False) + self._test_thumbnail( + "crop", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} @@ -393,14 +415,22 @@ class MediaRepoTests(unittest.HomeserverTestCase): """ Override the config to generate only cropped thumbnails, but request a scaled one. """ - self._test_thumbnail("scale", None, False) + self._test_thumbnail( + "scale", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) def test_thumbnail_repeated_thumbnail(self) -> None: """Test that fetching the same thumbnail works, and deleting the on disk thumbnail regenerates it. """ self._test_thumbnail( - "scale", self.test_image.expected_scaled, self.test_image.expected_found + "scale", + self.test_image.expected_scaled, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) if not self.test_image.expected_found: @@ -457,8 +487,24 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) def _test_thumbnail( - self, method: str, expected_body: Optional[bytes], expected_found: bool + self, + method: str, + expected_body: Optional[bytes], + expected_found: bool, + unable_to_thumbnail: bool = False, ) -> None: + """Test the given thumbnailing method works as expected. + + Args: + method: The thumbnailing method to use (crop, scale). + expected_body: The expected bytes from thumbnailing, or None if + test should just check for a valid image. + expected_found: True if the file should exist on the server, or False if + a 404/400 is expected. + unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or + False if the thumbnailing should succeed or a normal 404 is expected. + """ + params = "?width=32&height=32&method=" + method channel = make_request( self.reactor, @@ -481,6 +527,12 @@ class MediaRepoTests(unittest.HomeserverTestCase): if expected_found: self.assertEqual(channel.code, 200) + + self.assertEqual( + channel.headers.getRawHeaders(b"Cross-Origin-Resource-Policy"), + [b"cross-origin"], + ) + if expected_body is not None: self.assertEqual( channel.result["body"], expected_body, channel.result["body"] @@ -488,6 +540,16 @@ class MediaRepoTests(unittest.HomeserverTestCase): else: # ensure that the result is at least some valid image Image.open(BytesIO(channel.result["body"])) + elif unable_to_thumbnail: + # A 400 with a JSON body. + self.assertEqual(channel.code, 400) + self.assertEqual( + channel.json_body, + { + "errcode": "M_UNKNOWN", + "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", + }, + ) else: # A 404 with a JSON body. self.assertEqual(channel.code, 404) @@ -549,10 +611,26 @@ class MediaRepoTests(unittest.HomeserverTestCase): [b"noindex, nofollow, noarchive, noimageindex"], ) + def test_cross_origin_resource_policy_header(self) -> None: + """ + Test that the Cross-Origin-Resource-Policy header is set to "cross-origin" + allowing web clients to embed media from the downloads API. + """ + channel = self._req(b"inline; filename=out" + self.test_image.extension) -class TestSpamChecker: + headers = channel.headers + + self.assertEqual( + headers.getRawHeaders(b"Cross-Origin-Resource-Policy"), + [b"cross-origin"], + ) + + +class TestSpamCheckerLegacy: """A spam checker module that rejects all media that includes the bytes `evil`. + + Uses the legacy Spam-Checker API. """ def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None: @@ -593,7 +671,7 @@ class TestSpamChecker: return b"evil" in buf.getvalue() -class SpamCheckerTestCase(unittest.HomeserverTestCase): +class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase): servlets = [ login.register_servlets, admin.register_servlets, @@ -617,7 +695,8 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): { "spam_checker": [ { - "module": TestSpamChecker.__module__ + ".TestSpamChecker", + "module": TestSpamCheckerLegacy.__module__ + + ".TestSpamCheckerLegacy", "config": {}, } ] @@ -642,3 +721,62 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): self.helper.upload_media( self.upload_resource, data, tok=self.tok, expect_code=400 ) + + +EVIL_DATA = b"Some evil data" +EVIL_DATA_EXPERIMENT = b"Some evil data to trigger the experimental tuple API" + + +class SpamCheckerTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + admin.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + # Allow for uploading and downloading to/from the media repo + self.media_repo = hs.get_media_repository_resource() + self.download_resource = self.media_repo.children[b"download"] + self.upload_resource = self.media_repo.children[b"upload"] + + hs.get_module_api().register_spam_checker_callbacks( + check_media_file_for_spam=self.check_media_file_for_spam + ) + + async def check_media_file_for_spam( + self, file_wrapper: ReadableFileWrapper, file_info: FileInfo + ) -> Union[Codes, Literal["NOT_SPAM"]]: + buf = BytesIO() + await file_wrapper.write_chunks_to(buf.write) + + if buf.getvalue() == EVIL_DATA: + return Codes.FORBIDDEN + elif buf.getvalue() == EVIL_DATA_EXPERIMENT: + return (Codes.FORBIDDEN, {}) + else: + return "NOT_SPAM" + + def test_upload_innocent(self) -> None: + """Attempt to upload some innocent data that should be allowed.""" + self.helper.upload_media( + self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 + ) + + def test_upload_ban(self) -> None: + """Attempt to upload some data that includes bytes "evil", which should + get rejected by the spam checker. + """ + + self.helper.upload_media( + self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400 + ) + + self.helper.upload_media( + self.upload_resource, + EVIL_DATA_EXPERIMENT, + tok=self.tok, + expect_code=400, + ) diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index da325955f8..c0a2501742 100644 --- a/tests/rest/test_health.py +++ b/tests/rest/test_health.py
@@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus - from synapse.rest.health import HealthResource from tests import unittest @@ -26,5 +24,5 @@ class HealthCheckTests(unittest.HomeserverTestCase): def test_health(self) -> None: channel = self.make_request("GET", "/health", shorthand=False) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, 200) self.assertEqual(channel.result["body"], b"OK") diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 11f78f52b8..2091b08d89 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py
@@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from http import HTTPStatus - from twisted.web.resource import Resource from synapse.rest.well_known import well_known_resource @@ -38,7 +36,7 @@ class WellKnownTests(unittest.HomeserverTestCase): "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body, { @@ -57,7 +55,29 @@ class WellKnownTests(unittest.HomeserverTestCase): "GET", "/.well-known/matrix/client", shorthand=False ) - self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + self.assertEqual(channel.code, 404) + + @unittest.override_config( + { + "public_baseurl": "https://tesths", + "default_identity_server": "https://testis", + "extra_well_known_client_content": {"custom": False}, + } + ) + def test_client_well_known_custom(self) -> None: + channel = self.make_request( + "GET", "/.well-known/matrix/client", shorthand=False + ) + + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, + { + "m.homeserver": {"base_url": "https://tesths/"}, + "m.identity_server": {"base_url": "https://testis"}, + "custom": False, + }, + ) @unittest.override_config({"serve_server_wellknown": True}) def test_server_well_known(self) -> None: @@ -65,7 +85,7 @@ class WellKnownTests(unittest.HomeserverTestCase): "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, 200) self.assertEqual( channel.json_body, {"m.server": "test:443"}, @@ -75,4 +95,4 @@ class WellKnownTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/.well-known/matrix/server", shorthand=False ) - self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + self.assertEqual(channel.code, 404)