From 35b1900f00b77e754efb909eae0a2f0c94e968cb Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 30 Nov 2021 10:53:54 +0100 Subject: Convert status codes to `HTTPStatus` in `tests.rest.admin` (#11455) --- tests/rest/admin/test_admin.py | 74 +++-- tests/rest/admin/test_background_updates.py | 2 +- tests/rest/admin/test_device.py | 101 ++++--- tests/rest/admin/test_event_reports.py | 122 +++++--- tests/rest/admin/test_media.py | 191 ++++++++---- tests/rest/admin/test_registration_tokens.py | 211 ++++++++++--- tests/rest/admin/test_room.py | 146 ++++----- tests/rest/admin/test_server_notice.py | 45 +-- tests/rest/admin/test_statistics.py | 102 +++++-- tests/rest/admin/test_user.py | 429 ++++++++++++++------------- tests/rest/admin/test_username_available.py | 20 +- 11 files changed, 886 insertions(+), 557 deletions(-) (limited to 'tests') diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index af849bd471..3adadcb46b 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os import urllib.parse +from http import HTTPStatus from unittest.mock import Mock from twisted.internet.defer import Deferred @@ -41,7 +41,7 @@ class VersionTestCase(unittest.HomeserverTestCase): def test_version_string(self): channel = self.make_request("GET", self.url, shorthand=False) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual( {"server_version", "python_version"}, set(channel.json_body.keys()) ) @@ -70,11 +70,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): content={"localpart": "test"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) group_id = channel.json_body["group_id"] - self._check_group(group_id, expect_code=200) + self._check_group(group_id, expect_code=HTTPStatus.OK) # Invite/join another user @@ -82,13 +82,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={} ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) url = "/groups/%s/self/accept_invite" % (group_id,) channel = self.make_request( "PUT", url.encode("ascii"), access_token=self.other_user_token, content={} ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Check other user knows they're in the group self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -103,10 +103,10 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): content={"localpart": "test"}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - # Check group returns 404 - self._check_group(group_id, expect_code=404) + # Check group returns HTTPStatus.NOT_FOUND + self._check_group(group_id, expect_code=HTTPStatus.NOT_FOUND) # Check users don't think they're in the group self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -122,15 +122,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): "GET", url.encode("ascii"), access_token=self.admin_user_tok ) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token)""" channel = self.make_request("GET", b"/joined_groups", access_token=access_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) return channel.json_body["groups"] @@ -210,10 +208,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Should be quarantined self.assertEqual( - 404, - int(channel.code), + HTTPStatus.NOT_FOUND, + channel.code, msg=( - "Expected to receive a 404 on accessing quarantined media: %s" + "Expected to receive a HTTPStatus.NOT_FOUND on accessing quarantined media: %s" % server_and_media_id ), ) @@ -232,8 +230,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Expect a forbidden error self.assertEqual( - 403, - int(channel.result["code"]), + HTTPStatus.FORBIDDEN, + channel.code, msg="Expected forbidden on quarantining media as a non-admin", ) @@ -247,8 +245,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Expect a forbidden error self.assertEqual( - 403, - int(channel.result["code"]), + HTTPStatus.FORBIDDEN, + channel.code, msg="Expected forbidden on quarantining media as a non-admin", ) @@ -279,7 +277,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): ) # Should be successful - self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code) # Quarantine the media url = "/_synapse/admin/v1/media/quarantine/%s/%s" % ( @@ -292,7 +290,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Attempt to access the media self._ensure_quarantined(admin_user_tok, server_name_and_media_id) @@ -348,11 +346,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual( - json.loads(channel.result["body"].decode("utf-8")), - {"num_quarantined": 2}, - "Expected 2 quarantined items", + channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" ) # Convert mxc URLs to server/media_id strings @@ -396,11 +392,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual( - json.loads(channel.result["body"].decode("utf-8")), - {"num_quarantined": 2}, - "Expected 2 quarantined items", + channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items" ) # Attempt to access each piece of media @@ -432,7 +426,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),) channel = self.make_request("POST", url, access_token=admin_user_tok) self.pump(1.0) - self.assertEqual(200, int(channel.code), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Quarantine all media by this user url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote( @@ -444,11 +438,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): access_token=admin_user_tok, ) self.pump(1.0) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual( - json.loads(channel.result["body"].decode("utf-8")), - {"num_quarantined": 1}, - "Expected 1 quarantined item", + channel.json_body, {"num_quarantined": 1}, "Expected 1 quarantined item" ) # Attempt to access each piece of media, the first should fail, the @@ -467,10 +459,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Shouldn't be quarantined self.assertEqual( - 200, - int(channel.code), + HTTPStatus.OK, + channel.code, msg=( - "Expected to receive a 200 on accessing not-quarantined media: %s" + "Expected to receive a HTTPStatus.OK on accessing not-quarantined media: %s" % server_and_media_id_2 ), ) @@ -499,7 +491,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): def test_purge_history(self): """ Simple test of purge history API. - Test only that is is possible to call, get status 200 and purge_id. + Test only that is is possible to call, get status HTTPStatus.OK and purge_id. """ channel = self.make_request( @@ -509,7 +501,7 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIn("purge_id", channel.json_body) purge_id = channel.json_body["purge_id"] @@ -520,5 +512,5 @@ class PurgeHistoryTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("complete", channel.json_body["status"]) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 62f242baf6..a5423af652 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -46,7 +46,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ) def test_requester_is_no_admin(self, method: str, url: str): """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ self.register_user("user", "pass", admin=False) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index a3679be205..baff057c56 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -13,6 +13,7 @@ # limitations under the License. import urllib.parse +from http import HTTPStatus from parameterized import parameterized @@ -53,7 +54,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request(method, self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) @@ -67,13 +72,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) def test_user_does_not_exist(self, method: str): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ url = ( "/_synapse/admin/v2/users/@unknown_person:test/devices/%s" @@ -86,13 +95,13 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) def test_user_is_not_local(self, method: str): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = ( "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s" @@ -105,12 +114,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) def test_unknown_device(self): """ - Tests that a lookup for a device that does not exist returns either 404 or 200. + Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK. """ url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote( self.other_user @@ -122,7 +131,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) channel = self.make_request( @@ -131,7 +140,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) channel = self.make_request( "DELETE", @@ -139,8 +148,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - # Delete unknown device returns status 200 - self.assertEqual(200, channel.code, msg=channel.json_body) + # Delete unknown device returns status HTTPStatus.OK + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def test_update_device_too_long_display_name(self): """ @@ -167,7 +176,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): content=update, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. @@ -177,12 +186,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) def test_update_no_display_name(self): """ - Tests that a update for a device without JSON returns a 200 + Tests that a update for a device without JSON returns a HTTPStatus.OK """ # Set iniital display name. update = {"display_name": "new display"} @@ -198,7 +207,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Ensure the display name was not updated. channel = self.make_request( @@ -207,7 +216,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) def test_update_display_name(self): @@ -222,7 +231,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): content={"display_name": "new displayname"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Check new display_name channel = self.make_request( @@ -231,7 +240,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) def test_get_device(self): @@ -244,7 +253,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) # Check that all fields are available self.assertIn("user_id", channel.json_body) @@ -269,7 +278,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Ensure that the number of devices is decreased res = self.get_success(self.handler.get_devices_by_user(self.other_user)) @@ -299,7 +308,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -314,12 +327,16 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ url = "/_synapse/admin/v2/users/@unknown_person:test/devices" channel = self.make_request( @@ -328,12 +345,12 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" @@ -343,7 +360,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) def test_user_has_no_devices(self): @@ -359,7 +376,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["devices"])) @@ -379,7 +396,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(number_devices, channel.json_body["total"]) self.assertEqual(number_devices, len(channel.json_body["devices"])) self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) @@ -417,7 +434,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -432,12 +453,16 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" channel = self.make_request( @@ -446,12 +471,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" @@ -461,12 +486,12 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) def test_unknown_devices(self): """ - Tests that a remove of a device that does not exist returns 200. + Tests that a remove of a device that does not exist returns HTTPStatus.OK. """ channel = self.make_request( "POST", @@ -475,8 +500,8 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): content={"devices": ["unknown_device1", "unknown_device2"]}, ) - # Delete unknown devices returns status 200 - self.assertEqual(200, channel.code, msg=channel.json_body) + # Delete unknown devices returns status HTTPStatus.OK + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def test_delete_devices(self): """ @@ -505,7 +530,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): content={"devices": device_ids}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) res = self.get_success(self.handler.get_devices_by_user(self.other_user)) self.assertEqual(0, len(res)) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index e9ef89731f..a9c46ec62d 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json +from http import HTTPStatus import synapse.rest.admin from synapse.api.errors import Codes @@ -76,12 +76,16 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ channel = self.make_request( @@ -90,7 +94,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_default_success(self): @@ -104,7 +112,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -121,7 +129,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -138,7 +146,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -155,7 +163,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["event_reports"]), 10) @@ -172,7 +180,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -192,7 +200,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertNotIn("next_token", channel.json_body) @@ -212,7 +220,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["event_reports"]), 5) self.assertNotIn("next_token", channel.json_body) @@ -234,7 +242,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -252,7 +260,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) report = 1 @@ -265,7 +273,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): def test_invalid_search_order(self): """ - Testing that a invalid search order returns a 400 + Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST """ channel = self.make_request( @@ -274,13 +282,17 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual("Unknown direction: bar", channel.json_body["error"]) def test_limit_is_negative(self): """ - Testing that a negative limit parameter returns a 400 + Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST """ channel = self.make_request( @@ -289,12 +301,16 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_from_is_negative(self): """ - Testing that a negative from parameter returns a 400 + Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST """ channel = self.make_request( @@ -303,7 +319,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self): @@ -319,7 +339,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -332,7 +352,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 20) self.assertNotIn("next_token", channel.json_body) @@ -345,7 +365,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -359,7 +379,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["event_reports"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -372,10 +392,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - json.dumps({"score": -100, "reason": "this makes me sad"}), + {"score": -100, "reason": "this makes me sad"}, access_token=user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def _create_event_and_report_without_parameters(self, room_id, user_tok): """Create and report an event, but omit reason and score""" @@ -385,10 +405,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - json.dumps({}), + {}, access_token=user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def _check_fields(self, content): """Checks that all attributes are present in an event report""" @@ -439,12 +459,16 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ channel = self.make_request( @@ -453,7 +477,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_default_success(self): @@ -467,12 +495,12 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) def test_invalid_report_id(self): """ - Testing that an invalid `report_id` returns a 400. + Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST. """ # `report_id` is negative @@ -482,7 +510,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -496,7 +528,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -510,7 +546,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "The report_id parameter must be a string representing a positive integer.", @@ -519,7 +559,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): def test_report_id_not_found(self): """ - Testing that a not existing `report_id` returns a 404. + Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND. """ channel = self.make_request( @@ -528,7 +568,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.NOT_FOUND, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual("Event report not found", channel.json_body["error"]) @@ -540,10 +584,10 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "rooms/%s/report/%s" % (room_id, event_id), - json.dumps({"score": -100, "reason": "this makes me sad"}), + {"score": -100, "reason": "this makes me sad"}, access_token=user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def _check_fields(self, content): """Checks that all attributes are present in a event report""" diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index db0e78c039..6618279dd1 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import os +from http import HTTPStatus from parameterized import parameterized @@ -56,7 +56,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): channel = self.make_request("DELETE", url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -74,12 +78,16 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_media_does_not_exist(self): """ - Tests that a lookup for a media that does not exist returns a 404 + Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") @@ -89,12 +97,12 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_media_is_not_local(self): """ - Tests that a lookup for a media that is not a local returns a 400 + Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") @@ -104,7 +112,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) def test_delete_media(self): @@ -117,7 +125,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 + upload_resource, + SMALL_PNG, + tok=self.admin_user_tok, + expect_code=HTTPStatus.OK, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -137,10 +148,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Should be successful self.assertEqual( - 200, + HTTPStatus.OK, channel.code, msg=( - "Expected to receive a 200 on accessing media: %s" % server_and_media_id + "Expected to receive a HTTPStatus.OK on accessing media: %s" + % server_and_media_id ), ) @@ -157,7 +169,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -174,10 +186,10 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) self.assertEqual( - 404, + HTTPStatus.NOT_FOUND, channel.code, msg=( - "Expected to receive a 404 on accessing deleted media: %s" + "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" % server_and_media_id ), ) @@ -216,7 +228,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -232,12 +248,16 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_media_is_not_local(self): """ - Tests that a lookup for media that is not local returns a 400 + Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" @@ -247,7 +267,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) def test_missing_parameter(self): @@ -260,7 +280,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual( "Missing integer query parameter 'before_ts'", channel.json_body["error"] @@ -276,7 +300,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -289,7 +317,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " @@ -303,7 +335,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter size_gt must be a string representing a positive integer.", @@ -316,7 +352,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", @@ -345,7 +385,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( media_id, @@ -370,7 +410,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -382,7 +422,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -406,7 +446,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -417,7 +457,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -439,10 +479,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "/profile/%s/avatar_url" % (self.admin_user,), - content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}), + content={"avatar_url": "mxc://%s" % (server_and_media_id,)}, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -450,7 +490,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -461,7 +501,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -484,10 +524,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.avatar" % (room_id,), - content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}), + content={"url": "mxc://%s" % (server_and_media_id,)}, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) now_ms = self.clock.time_msec() channel = self.make_request( @@ -495,7 +535,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self._access_media(server_and_media_id) @@ -506,7 +546,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( server_and_media_id.split("/")[1], @@ -523,7 +563,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 + upload_resource, + SMALL_PNG, + tok=self.admin_user_tok, + expect_code=HTTPStatus.OK, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -554,10 +597,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): if expect_success: self.assertEqual( - 200, + HTTPStatus.OK, channel.code, msg=( - "Expected to receive a 200 on accessing media: %s" + "Expected to receive a HTTPStatus.OK on accessing media: %s" % server_and_media_id ), ) @@ -565,10 +608,10 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertTrue(os.path.exists(local_path)) else: self.assertEqual( - 404, + HTTPStatus.NOT_FOUND, channel.code, msg=( - "Expected to receive a 404 on accessing deleted media: %s" + "Expected to receive a HTTPStatus.NOT_FOUND on accessing deleted media: %s" % (server_and_media_id) ), ) @@ -597,7 +640,10 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 + upload_resource, + SMALL_PNG, + tok=self.admin_user_tok, + expect_code=HTTPStatus.OK, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -617,7 +663,11 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): b"{}", ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["quarantine", "unquarantine"]) @@ -634,7 +684,11 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_quarantine_media(self): @@ -652,7 +706,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -665,7 +719,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -690,7 +744,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) # verify that is not in quarantine @@ -718,7 +772,10 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200 + upload_resource, + SMALL_PNG, + tok=self.admin_user_tok, + expect_code=HTTPStatus.OK, ) # Extract media ID from the response server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' @@ -734,7 +791,11 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url % (action, self.media_id), b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["protect", "unprotect"]) @@ -751,7 +812,11 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_protect_media(self): @@ -769,7 +834,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -782,7 +847,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) @@ -816,7 +881,11 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -832,7 +901,11 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -845,7 +918,11 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts must be a positive integer.", @@ -858,7 +935,11 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( "Query parameter before_ts you provided is from the year 1970. " diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 9bac423ae0..63087955f2 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -14,6 +14,7 @@ import random import string +from http import HTTPStatus import synapse.rest.admin from synapse.api.errors import Codes @@ -63,7 +64,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): def test_create_no_auth(self): """Try to create a token without authentication.""" channel = self.make_request("POST", self.url + "/new", {}) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_create_requester_not_admin(self): @@ -74,7 +79,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_create_using_defaults(self): @@ -86,7 +95,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -110,7 +119,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) @@ -131,7 +140,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 16) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -149,7 +158,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_token_invalid_chars(self): @@ -165,7 +178,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_token_already_exists(self): @@ -180,7 +197,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"]) + self.assertEqual(HTTPStatus.OK, channel1.code, msg=channel1.json_body) channel2 = self.make_request( "POST", @@ -188,7 +205,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): data, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"]) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body) self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) def test_create_unable_to_generate_token(self): @@ -220,7 +237,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(500, channel.code, msg=channel.json_body) def test_create_uses_allowed(self): """Check you can only create a token with good values for uses_allowed.""" @@ -231,7 +248,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 0) # Should fail with negative integer @@ -241,7 +258,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -251,7 +272,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_expiry_time(self): @@ -263,7 +288,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() - 10000}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with float @@ -273,7 +302,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": self.clock.time_msec() + 1000000.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_create_length(self): @@ -285,7 +318,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 64}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["token"]), 64) # Should fail with 0 @@ -295,7 +328,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -305,7 +342,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": -5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a float @@ -315,7 +356,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 8.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with 65 @@ -325,7 +370,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"length": 65}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # UPDATING @@ -337,7 +386,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_update_requester_not_admin(self): @@ -348,7 +401,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_update_non_existent(self): @@ -360,7 +417,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.NOT_FOUND, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_update_uses_allowed(self): @@ -375,7 +436,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertIsNone(channel.json_body["expiry_time"]) @@ -386,7 +447,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 0}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 0) self.assertIsNone(channel.json_body["expiry_time"]) @@ -397,7 +458,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": None}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -408,7 +469,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": 1.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail with a negative integer @@ -418,7 +483,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"uses_allowed": -5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_update_expiry_time(self): @@ -434,7 +503,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -445,7 +514,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": None}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIsNone(channel.json_body["expiry_time"]) self.assertIsNone(channel.json_body["uses_allowed"]) @@ -457,7 +526,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": past_time}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # Should fail a float @@ -467,7 +540,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {"expiry_time": new_expiry_time + 0.5}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) def test_update_both(self): @@ -488,7 +565,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) @@ -509,7 +586,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) # DELETING @@ -521,7 +602,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_delete_requester_not_admin(self): @@ -532,7 +617,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_delete_non_existent(self): @@ -544,7 +633,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.NOT_FOUND, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_delete(self): @@ -559,7 +652,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # GETTING ONE @@ -570,7 +663,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url + "/1234", # Token doesn't exist but that doesn't matter {}, ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_get_requester_not_admin(self): @@ -581,7 +678,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_get_non_existent(self): @@ -593,7 +694,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.NOT_FOUND, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) def test_get(self): @@ -608,7 +713,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["token"], token) self.assertIsNone(channel.json_body["uses_allowed"]) self.assertIsNone(channel.json_body["expiry_time"]) @@ -620,7 +725,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): def test_list_no_auth(self): """Try to list tokens without authentication.""" channel = self.make_request("GET", self.url, {}) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_list_requester_not_admin(self): @@ -631,7 +740,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_list_all(self): @@ -646,7 +759,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["registration_tokens"]), 1) token_info = channel.json_body["registration_tokens"][0] self.assertEqual(token_info["token"], token) @@ -664,7 +777,11 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) def _test_list_query_parameter(self, valid: str): """Helper used to test both valid=true and valid=false.""" @@ -696,7 +813,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(len(channel.json_body["registration_tokens"]), 2) token_info_1 = channel.json_body["registration_tokens"][0] token_info_2 = channel.json_body["registration_tokens"][1] diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 07077aff78..56b7a438b6 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -66,7 +66,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ channel = self.make_request( @@ -76,12 +76,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.other_user_tok, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self): """ - Check that unknown rooms/server return error 404. + Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. """ url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test" @@ -92,12 +92,12 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_room_is_not_valid(self): """ - Check that invalid room names, return an error 400. + Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. """ url = "/_synapse/admin/v1/rooms/%s" % "invalidroom" @@ -108,7 +108,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -127,7 +127,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIn("new_room_id", channel.json_body) self.assertIn("kicked_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -146,7 +146,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -165,7 +165,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self): @@ -181,7 +181,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_room_and_block(self): @@ -207,7 +207,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -240,7 +240,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -274,7 +274,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -305,9 +305,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) # The room is now blocked. - self.assertEqual( - HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self._is_blocked(room_id) def test_shutdown_room_consent(self): @@ -327,7 +325,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert that the user is getting consent error self.helper.send( - self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 + self.room_id, + body="foo", + tok=self.other_user_tok, + expect_code=HTTPStatus.FORBIDDEN, ) # Test that room is not purged @@ -345,7 +346,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -374,7 +375,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -391,7 +392,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -406,7 +407,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._has_no_members(self.room_id) # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=403) + self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) def _is_blocked(self, room_id, expect=True): """Assert that the room is blocked or not""" @@ -502,7 +503,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): ) def test_requester_is_no_admin(self, method: str, url: str): """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ channel = self.make_request( @@ -524,7 +525,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): ) def test_room_does_not_exist(self, method: str, url: str): """ - Check that unknown rooms/server return error 404. + Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. """ channel = self.make_request( @@ -545,7 +546,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): ) def test_room_is_not_valid(self, method: str, url: str): """ - Check that invalid room names, return an error 400. + Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. """ channel = self.make_request( @@ -854,7 +855,10 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): # Assert that the user is getting consent error self.helper.send( - self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 + self.room_id, + body="foo", + tok=self.other_user_tok, + expect_code=HTTPStatus.FORBIDDEN, ) # Test that room is not purged @@ -951,7 +955,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self._has_no_members(self.room_id) # Assert we can no longer peek into the room - self._assert_peek(self.room_id, expect_code=403) + self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN) def _is_blocked(self, room_id: str, expect: bool = True) -> None: """Assert that the room is blocked or not""" @@ -1094,7 +1098,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) # Check request completed successfully - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Check that response json body contains a "rooms" key self.assertTrue( @@ -1178,7 +1182,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertTrue("rooms" in channel.json_body) for r in channel.json_body["rooms"]: @@ -1218,7 +1222,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def test_correct_room_attributes(self): """Test the correct attributes for a room are returned""" @@ -1241,7 +1245,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1273,7 +1277,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1328,7 +1332,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1467,7 +1471,7 @@ class RoomTestCase(unittest.HomeserverTestCase): def _search_test( expected_room_id: Optional[str], search_term: str, - expected_http_code: int = 200, + expected_http_code: int = HTTPStatus.OK, ): """Search for a room and check that the returned room's id is a match @@ -1485,7 +1489,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if expected_http_code != 200: + if expected_http_code != HTTPStatus.OK: return # Check that rooms were returned @@ -1528,7 +1532,7 @@ class RoomTestCase(unittest.HomeserverTestCase): _search_test(None, "foo") _search_test(None, "bar") - _search_test(None, "", expected_http_code=400) + _search_test(None, "", expected_http_code=HTTPStatus.BAD_REQUEST) # Test that the whole room id returns the room _search_test(room_id_1, room_id_1) @@ -1565,7 +1569,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name")) @@ -1598,7 +1602,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIn("room_id", channel.json_body) self.assertIn("name", channel.json_body) @@ -1630,7 +1634,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["joined_local_devices"]) # Have another user join the room @@ -1644,7 +1648,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(2, channel.json_body["joined_local_devices"]) # leave room @@ -1656,7 +1660,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["joined_local_devices"]) def test_room_members(self): @@ -1687,7 +1691,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"] @@ -1700,7 +1704,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertCountEqual( ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"] @@ -1718,7 +1722,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertIn("state", channel.json_body) # testing that the state events match is painful and not done here. We assume that # the create_room already does the right thing, so no need to verify that we got @@ -1733,7 +1737,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1776,7 +1780,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ body = json.dumps({"user_id": self.second_user_id}) @@ -1787,7 +1791,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.second_tok, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -1803,12 +1807,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_local_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ body = json.dumps({"user_id": "@unknown:test"}) @@ -1819,7 +1823,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_remote_user(self): @@ -1835,7 +1839,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "This endpoint can only be used with local users", channel.json_body["error"], @@ -1843,7 +1847,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): def test_room_does_not_exist(self): """ - Check that unknown rooms/server return error 404. + Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. """ body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/!unknown:test" @@ -1855,12 +1859,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual("No known servers", channel.json_body["error"]) def test_room_is_not_valid(self): """ - Check that invalid room names, return an error 400. + Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. """ body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/invalidroom" @@ -1872,7 +1876,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom was not legal room ID or room alias", channel.json_body["error"], @@ -1891,7 +1895,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1901,7 +1905,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, channel.code, msg=channel.json_body) + self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_not_member(self): @@ -1922,7 +1926,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_join_private_room_if_member(self): @@ -1950,7 +1954,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEquals(200, channel.code, msg=channel.json_body) + self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. @@ -1964,7 +1968,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1974,7 +1978,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, channel.code, msg=channel.json_body) + self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_owner(self): @@ -1995,7 +1999,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -2005,7 +2009,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, channel.code, msg=channel.json_body) + self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_context_as_non_admin(self): @@ -2039,7 +2043,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals(403, channel.code, msg=channel.json_body) + self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self): @@ -2069,7 +2073,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEquals(200, channel.code, msg=channel.json_body) + self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEquals( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) @@ -2128,7 +2132,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) @@ -2155,7 +2159,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Now we test that we can join the room (we should have received an # invite) and can ban a user. @@ -2181,7 +2185,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.second_user_id, tok=self.second_tok) @@ -2215,11 +2219,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - # We expect this to fail with a 400 as there are no room admins. + # We expect this to fail with a HTTPStatus.BAD_REQUEST as there are no room admins. # # (Note we assert the error message to ensure that it's not denied for # some other reason) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body["error"], "No local admin user in room with power to update power levels.", @@ -2249,7 +2253,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): @parameterized.expand([("PUT",), ("GET",)]) def test_requester_is_no_admin(self, method: str): - """If the user is not a server admin, an error 403 is returned.""" + """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.""" channel = self.make_request( method, @@ -2263,7 +2267,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): @parameterized.expand([("PUT",), ("GET",)]) def test_room_is_not_valid(self, method: str): - """Check that invalid room names, return an error 400.""" + """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.""" channel = self.make_request( method, diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index fbceba3254..0b9da4c732 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus from typing import List import synapse.rest.admin @@ -52,7 +53,11 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): """Try to send a server notice without authentication.""" channel = self.make_request("POST", self.url) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -63,12 +68,16 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) def test_user_does_not_exist(self): - """Tests that a lookup for a user that does not exist returns a 404""" + """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" channel = self.make_request( "POST", self.url, @@ -76,13 +85,13 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": "@unknown_person:test", "content": ""}, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ channel = self.make_request( "POST", @@ -94,7 +103,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "Server notices can only be sent to local users", channel.json_body["error"] ) @@ -110,7 +119,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) # no content @@ -121,7 +130,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # no body @@ -132,7 +141,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user, "content": ""}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'body' not in content", channel.json_body["error"]) @@ -144,7 +153,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): content={"user_id": self.other_user, "content": {"body": ""}}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'msgtype' not in content", channel.json_body["error"]) @@ -160,7 +169,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( "Server notices are not enabled on this server", channel.json_body["error"] @@ -185,7 +194,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -216,7 +225,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has no new invites or memberships self._check_invite_and_join_status(self.other_user, 0, 1) @@ -250,7 +259,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -293,7 +302,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -333,7 +342,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg one"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -382,7 +391,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): "content": {"msgtype": "m.text", "body": "test msg two"}, }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # user has one invite invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) @@ -440,7 +449,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "/_matrix/client/r0/sync", access_token=token ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) # Get the messages room = channel.json_body["rooms"]["join"][room_id] diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index ece89a65ac..43d8ca032b 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json +from http import HTTPStatus from typing import Any, Dict, List, Optional import synapse.rest.admin @@ -47,21 +47,29 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.UNAUTHORIZED, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): """ - If the user is not a server admin, an error 403 is returned. + If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ channel = self.make_request( "GET", self.url, - json.dumps({}), + {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.FORBIDDEN, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -75,7 +83,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -85,7 +97,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative limit @@ -95,7 +111,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from_ts @@ -105,7 +125,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative until_ts @@ -115,7 +139,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # until_ts smaller from_ts @@ -125,7 +153,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # empty search term @@ -135,7 +167,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid search order @@ -145,7 +181,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_limit(self): @@ -160,7 +200,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -178,7 +218,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -196,7 +236,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["users"]), 10) @@ -218,7 +258,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -231,7 +271,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -244,7 +284,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -257,7 +297,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -274,7 +314,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["users"])) @@ -371,7 +411,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media starting at `ts1` after creating first media @@ -381,7 +421,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) self._create_media(self.other_user_tok, 3) @@ -396,7 +436,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media until `ts2` and earlier @@ -405,7 +445,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) def test_search_term(self): @@ -417,7 +457,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 20) # filter user 1 and 10-19 by `user_id` @@ -426,7 +466,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foo_user_1", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 11) # filter on this user in `displayname` @@ -435,7 +475,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=bar_user_10", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["displayname"], "bar_user_10") self.assertEqual(channel.json_body["total"], 1) @@ -445,7 +485,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url + "?search_term=foobar", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) def _create_users_with_media(self, number_users: int, media_per_user: int): @@ -471,7 +511,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): for _ in range(number_media): # Upload some media into the room self.helper.upload_media( - upload_resource, SMALL_PNG, tok=user_token, expect_code=200 + upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK ) def _check_fields(self, content: List[Dict[str, Any]]): @@ -505,7 +545,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["user_id"] for row in channel.json_body["users"]] diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 5011e54563..03aa689ace 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -17,6 +17,7 @@ import hmac import os import urllib.parse from binascii import unhexlify +from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock, patch @@ -74,7 +75,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( "Shared secret registration is not enabled", channel.json_body["error"] ) @@ -106,7 +107,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = {"nonce": nonce} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds @@ -114,7 +115,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self): @@ -137,7 +138,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("HMAC incorrect", channel.json_body["error"]) def test_register_correct_nonce(self): @@ -164,7 +165,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) def test_nonce_reuse(self): @@ -187,13 +188,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self): @@ -214,7 +215,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be an empty body present channel = self.make_request("POST", self.url, {}) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) # @@ -224,28 +225,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Must be present channel = self.make_request("POST", self.url, {"nonce": nonce()}) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = {"nonce": nonce(), "username": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # @@ -256,28 +257,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = {"nonce": nonce(), "username": "a"} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = {"nonce": nonce(), "username": "a", "password": 1234} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = {"nonce": nonce(), "username": "a", "password": "A" * 1000} channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # @@ -293,7 +294,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) def test_displayname(self): @@ -318,11 +319,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob1:test/displayname") - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None @@ -342,11 +343,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob2:test/displayname") - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty @@ -366,11 +367,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) # set displayname channel = self.make_request("GET", self.url) @@ -389,11 +390,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob4:test/displayname") - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("Bob's Name", channel.json_body["displayname"]) @override_config( @@ -437,7 +438,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): } channel = self.make_request("POST", self.url, body) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -461,7 +462,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -473,7 +474,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, access_token=other_user_token) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_all_users(self): @@ -489,7 +490,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, channel.json_body["total"]) @@ -503,7 +504,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): expected_user_id: Optional[str], search_term: str, search_field: Optional[str] = "name", - expected_http_code: Optional[int] = 200, + expected_http_code: Optional[int] = HTTPStatus.OK, ): """Search for a user and check that the returned user's id is a match @@ -525,7 +526,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) - if expected_http_code != 200: + if expected_http_code != HTTPStatus.OK: return # Check that users were returned @@ -586,7 +587,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -596,7 +597,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid guests @@ -606,7 +607,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid deactivated @@ -616,7 +617,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # unkown order_by @@ -626,7 +627,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order @@ -636,7 +637,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) def test_limit(self): @@ -654,7 +655,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], "5") @@ -675,7 +676,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -696,7 +697,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(len(channel.json_body["users"]), 10) @@ -719,7 +720,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -732,7 +733,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -745,7 +746,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], "19") @@ -759,7 +760,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -862,7 +863,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) returned_order = [row["name"] for row in channel.json_body["users"]] @@ -936,7 +937,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -947,7 +948,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", url, access_token=self.other_user_token) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -957,12 +958,12 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): """ - Tests that deactivation for a user that does not exist returns a 404 + Tests that deactivation for a user that does not exist returns a HTTPStatus.NOT_FOUND """ channel = self.make_request( @@ -971,7 +972,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_erase_is_not_bool(self): @@ -986,18 +987,18 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that deactivation for a user that is not a local returns a 400 + Tests that deactivation for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain" channel = self.make_request("POST", url, access_token=self.admin_user_tok) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only deactivate local users", channel.json_body["error"]) def test_deactivate_user_erase_true(self): @@ -1012,7 +1013,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1027,7 +1028,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1036,7 +1037,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1057,7 +1058,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1072,7 +1073,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": False}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1081,7 +1082,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1111,7 +1112,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -1126,7 +1127,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -1135,7 +1136,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(True, channel.json_body["deactivated"]) self.assertEqual(0, len(channel.json_body["threepids"])) @@ -1195,7 +1196,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.other_user_token, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( @@ -1205,12 +1206,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=b"{}", ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual("You are not a server admin", channel.json_body["error"]) def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ channel = self.make_request( @@ -1219,7 +1220,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -1234,7 +1235,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"admin": "not_bool"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) # deactivated not bool @@ -1244,7 +1245,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": "not_bool"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not str @@ -1254,7 +1255,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"password": True}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # password not length @@ -1264,7 +1265,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"password": "x" * 513}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # user_type not valid @@ -1274,7 +1275,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"user_type": "new type"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # external_ids not valid @@ -1286,7 +1287,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"} }, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1295,7 +1296,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"external_ids": {"external_id": "id"}}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) # threepids not valid @@ -1305,7 +1306,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": {"medium": "email", "wrong_address": "id"}}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) channel = self.make_request( @@ -1314,7 +1315,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": {"address": "value"}}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_get_user(self): @@ -1327,7 +1328,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) @@ -1370,7 +1371,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1433,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1461,9 +1462,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): # before limit of monthly active users is reached channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok) - if channel.code != 200: + if channel.code != HTTPStatus.OK: raise HttpResponseException( - channel.code, channel.result["reason"], channel.result["body"] + channel.code, channel.result["reason"], channel.json_body ) # Set monthly active users to the limit @@ -1625,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "hahaha"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) def test_set_displayname(self): @@ -1641,7 +1642,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "foobar"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1652,7 +1653,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("foobar", channel.json_body["displayname"]) @@ -1674,7 +1675,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1700,7 +1701,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1716,7 +1717,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1732,7 +1733,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"threepids": []}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1759,7 +1760,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1778,7 +1779,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1800,7 +1801,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # other user has this two threepids - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["threepids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1819,7 +1820,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): url_first_user, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) @@ -1848,7 +1849,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) # result does not always have the same sort order, therefore it becomes sorted @@ -1880,7 +1881,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1899,7 +1900,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1918,7 +1919,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"external_ids": []}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["external_ids"])) @@ -1947,7 +1948,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -1973,7 +1974,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): }, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2005,7 +2006,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # must fail - self.assertEqual(409, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CONFLICT, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("External id is already in use.", channel.json_body["error"]) @@ -2016,7 +2017,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2034,7 +2035,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(first_user, channel.json_body["name"]) self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertEqual( @@ -2065,7 +2066,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) @@ -2080,7 +2081,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2096,7 +2097,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2123,7 +2124,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"deactivated": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) @@ -2139,7 +2140,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"displayname": "Foobar"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) @@ -2163,7 +2164,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) # Reactivate the user. channel = self.make_request( @@ -2172,7 +2173,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNotNone(channel.json_body["password_hash"]) @@ -2194,7 +2195,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -2204,7 +2205,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2226,7 +2227,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Reactivate the user without a password. @@ -2236,7 +2237,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": False}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) @@ -2255,7 +2256,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"admin": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2266,7 +2267,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) @@ -2283,7 +2284,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"user_type": UserTypes.SUPPORT}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2294,7 +2295,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) @@ -2306,7 +2307,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"user_type": None}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2317,7 +2318,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertIsNone(channel.json_body["user_type"]) @@ -2347,7 +2348,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual(0, channel.json_body["deactivated"]) @@ -2360,7 +2361,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "deactivated": "false"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) # Check user is not deactivated channel = self.make_request( @@ -2369,7 +2370,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) @@ -2394,7 +2395,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"deactivated": True}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["deactivated"]) self.assertIsNone(channel.json_body["password_hash"]) self._is_erased(user_id, False) @@ -2445,7 +2446,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2460,7 +2461,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): @@ -2474,7 +2475,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2490,7 +2491,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2506,7 +2507,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["joined_rooms"])) @@ -2527,7 +2528,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) @@ -2574,7 +2575,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"]) @@ -2603,7 +2604,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2618,12 +2619,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" channel = self.make_request( @@ -2632,12 +2633,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" @@ -2647,7 +2648,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_get_pushers(self): @@ -2662,7 +2663,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) # Register the pusher @@ -2693,7 +2694,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) for p in channel.json_body["pushers"]: @@ -2732,7 +2733,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """Try to list media of an user without authentication.""" channel = self.make_request(method, self.url, {}) - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) @@ -2746,12 +2747,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) def test_user_does_not_exist(self, method: str): - """Tests that a lookup for a user that does not exist returns a 404""" + """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" url = "/_synapse/admin/v1/users/@unknown_person:test/media" channel = self.make_request( method, @@ -2759,12 +2760,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "DELETE"]) def test_user_is_not_local(self, method: str): - """Tests that a lookup for a user that is not a local returns a 400""" + """Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST""" url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" channel = self.make_request( @@ -2773,7 +2774,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_limit_GET(self): @@ -2789,7 +2790,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 5) self.assertEqual(channel.json_body["next_token"], 5) @@ -2808,7 +2809,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 5) self.assertEqual(len(channel.json_body["deleted_media"]), 5) @@ -2825,7 +2826,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -2844,7 +2845,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 15) self.assertEqual(len(channel.json_body["deleted_media"]), 15) @@ -2861,7 +2862,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["media"]), 10) @@ -2880,7 +2881,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 10) self.assertEqual(len(channel.json_body["deleted_media"]), 10) @@ -2894,7 +2895,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order @@ -2904,7 +2905,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # negative limit @@ -2914,7 +2915,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -2924,7 +2925,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self): @@ -2947,7 +2948,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2960,7 +2961,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2973,7 +2974,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -2987,7 +2988,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -3004,7 +3005,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["media"])) @@ -3019,7 +3020,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["deleted_media"])) @@ -3036,7 +3037,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["media"])) self.assertNotIn("next_token", channel.json_body) @@ -3062,7 +3063,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) self.assertEqual(number_media, len(channel.json_body["deleted_media"])) self.assertCountEqual(channel.json_body["deleted_media"], media_ids) @@ -3207,7 +3208,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, user_token, filename, expect_code=200 + upload_resource, image_data, user_token, filename, expect_code=HTTPStatus.OK ) # Extract media ID from the response @@ -3225,10 +3226,10 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - 200, + HTTPStatus.OK, channel.code, msg=( - f"Expected to receive a 200 on accessing media: {server_and_media_id}" + f"Expected to receive a HTTPStatus.OK on accessing media: {server_and_media_id}" ), ) @@ -3274,7 +3275,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_media_list)) returned_order = [row["media_id"] for row in channel.json_body["media"]] @@ -3310,14 +3311,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) return channel.json_body["access_token"] def test_no_auth(self): """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_not_admin(self): @@ -3326,7 +3327,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): "POST", self.url, b"{}", access_token=self.other_user_tok ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) def test_send_event(self): """Test that sending event as a user works.""" @@ -3351,7 +3352,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # We should only see the one device (from the login in `prepare`) self.assertEqual(len(channel.json_body["devices"]), 1) @@ -3363,21 +3364,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Logout with the puppet token channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def test_user_logout_all(self): """Tests that the target user calling `/logout/all` does *not* expire @@ -3388,23 +3389,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Logout all with the real user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # The puppet token should still work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # .. but the real user's tokens shouldn't channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) def test_admin_logout_all(self): """Tests that the admin user calling `/logout/all` does expire the @@ -3415,23 +3416,23 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Logout all with the admin user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) @unittest.override_config( { @@ -3459,7 +3460,10 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Now unaccept it and check that we can't send an event self.get_success(self.store.user_set_consent_version(self.other_user, "0.0")) self.helper.send_event( - room_id, "com.example.test", tok=self.other_user_tok, expect_code=403 + room_id, + "com.example.test", + tok=self.other_user_tok, + expect_code=HTTPStatus.FORBIDDEN, ) # Login in as the user @@ -3477,7 +3481,10 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Trying to join as the other user should fail due to reaching MAU limit. self.helper.join( - room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403 + room_id, + user=self.other_user, + tok=self.other_user_tok, + expect_code=HTTPStatus.FORBIDDEN, ) # Logging in as the other user and joining a room should work, even @@ -3512,7 +3519,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -3527,12 +3534,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=other_user2_token, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_is_not_local(self): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = self.url_prefix % "@unknown_person:unknown_domain" @@ -3541,7 +3548,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): url, access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) def test_get_whois_admin(self): @@ -3553,7 +3560,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -3568,7 +3575,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.url, access_token=other_user_token, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -3598,7 +3605,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. """ channel = self.make_request(method, self.url) - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) @@ -3609,18 +3616,18 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): other_user_token = self.login("user", "pass") channel = self.make_request(method, self.url, access_token=other_user_token) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["POST", "DELETE"]) def test_user_is_not_local(self, method: str): """ - Tests that shadow-banning for a user that is not a local returns a 400 + Tests that shadow-banning for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" channel = self.make_request(method, url, access_token=self.admin_user_tok) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) def test_success(self): """ @@ -3632,7 +3639,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): self.assertFalse(result.shadow_banned) channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual({}, channel.json_body) # Ensure the user is shadow-banned (and the cache was cleared). @@ -3643,7 +3650,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, access_token=self.admin_user_tok ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual({}, channel.json_body) # Ensure the user is no longer shadow-banned (and the cache was cleared). @@ -3677,7 +3684,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): """ channel = self.make_request(method, self.url, b"{}") - self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) @@ -3693,13 +3700,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "POST", "DELETE"]) def test_user_does_not_exist(self, method: str): """ - Tests that a lookup for a user that does not exist returns a 404 + Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" @@ -3709,7 +3716,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand( @@ -3721,7 +3728,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) def test_user_is_not_local(self, method: str, error_msg: str): """ - Tests that a lookup for a user that is not a local returns a 400 + Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ url = ( "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit" @@ -3733,7 +3740,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(error_msg, channel.json_body["error"]) def test_invalid_parameter(self): @@ -3748,7 +3755,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": "string"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # messages_per_second is negative @@ -3759,7 +3766,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": -1}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is a string @@ -3770,7 +3777,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": "string"}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is negative @@ -3781,7 +3788,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": -1}, ) - self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_return_zero_when_null(self): @@ -3806,7 +3813,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["burst_count"]) @@ -3820,7 +3827,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3831,7 +3838,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 10, "burst_count": 11}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(11, channel.json_body["burst_count"]) @@ -3842,7 +3849,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 20, "burst_count": 21}, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3852,7 +3859,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3862,7 +3869,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3872,6 +3879,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 4e1c49c28b..7978626e71 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + import synapse.rest.admin from synapse.api.errors import Codes, SynapseError from synapse.rest.client import login @@ -33,30 +35,38 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): async def check_username(username): if username == "allowed": return True - raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "User ID already taken.", + errcode=Codes.USER_IN_USE, + ) handler = self.hs.get_registration_handler() handler.check_username = check_username def test_username_available(self): """ - The endpoint should return a 200 response if the username does not exist + The endpoint should return a HTTPStatus.OK response if the username does not exist """ url = "%s?username=%s" % (self.url, "allowed") channel = self.make_request("GET", url, None, self.admin_user_tok) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["available"]) def test_username_unavailable(self): """ - The endpoint should return a 200 response if the username does not exist + The endpoint should return a HTTPStatus.OK response if the username does not exist """ url = "%s?username=%s" % (self.url, "disallowed") channel = self.make_request("GET", url, None, self.admin_user_tok) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, + channel.code, + msg=channel.json_body, + ) self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE") self.assertEqual(channel.json_body["error"], "User ID already taken.") -- cgit 1.5.1 From 28f5252c1f5e675aafc12e86b1237d2bedcd1a3c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 30 Nov 2021 08:23:53 -0500 Subject: Add missing copyright header. (#11460) --- changelog.d/11460.misc | 1 + tests/federation/transport/test_client.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+) create mode 100644 changelog.d/11460.misc (limited to 'tests') diff --git a/changelog.d/11460.misc b/changelog.d/11460.misc new file mode 100644 index 0000000000..fc6bc82b36 --- /dev/null +++ b/changelog.d/11460.misc @@ -0,0 +1 @@ +Fix a bug introduced in 1.47.0 where `send_join` could fail due to an outdated `ijson` version. diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index 0b19159961..a7031a55f2 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -1,3 +1,17 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json from synapse.api.room_versions import RoomVersions -- cgit 1.5.1 From 432a174bc192740ac7a0a755009f6099b8363ad9 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 30 Nov 2021 15:51:04 +0100 Subject: Remove unnecessary `json.dumps` from `tests.rest.admin` (#11461) The tests helpers automatically convert dictionaries to JSON payloads, no need to do it manually for each test. --- changelog.d/11461.misc | 1 + tests/rest/admin/test_room.py | 61 ++++++++++++++----------------------------- 2 files changed, 21 insertions(+), 41 deletions(-) create mode 100644 changelog.d/11461.misc (limited to 'tests') diff --git a/changelog.d/11461.misc b/changelog.d/11461.misc new file mode 100644 index 0000000000..92133f9eaa --- /dev/null +++ b/changelog.d/11461.misc @@ -0,0 +1 @@ +Remove unnecessary `json.dumps` from `tests.rest.admin`. \ No newline at end of file diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 56b7a438b6..681f9173ef 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import urllib.parse from http import HTTPStatus from typing import List, Optional @@ -118,12 +117,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ Tests that the user ID must be from local server but it does not have to exist. """ - body = json.dumps({"new_room_user_id": "@unknown:test"}) channel = self.make_request( "DELETE", self.url, - content=body, + content={"new_room_user_id": "@unknown:test"}, access_token=self.admin_user_tok, ) @@ -137,12 +135,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ Check that only local users can create new room to move members. """ - body = json.dumps({"new_room_user_id": "@not:exist.bla"}) channel = self.make_request( "DELETE", self.url, - content=body, + content={"new_room_user_id": "@not:exist.bla"}, access_token=self.admin_user_tok, ) @@ -156,12 +153,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ If parameter `block` is not boolean, return an error """ - body = json.dumps({"block": "NotBool"}) channel = self.make_request( "DELETE", self.url, - content=body, + content={"block": "NotBool"}, access_token=self.admin_user_tok, ) @@ -172,12 +168,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ If parameter `purge` is not boolean, return an error """ - body = json.dumps({"purge": "NotBool"}) channel = self.make_request( "DELETE", self.url, - content=body, + content={"purge": "NotBool"}, access_token=self.admin_user_tok, ) @@ -198,12 +193,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": True, "purge": True}) - channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body, + content={"block": True, "purge": True}, access_token=self.admin_user_tok, ) @@ -231,12 +224,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": False, "purge": True}) - channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body, + content={"block": False, "purge": True}, access_token=self.admin_user_tok, ) @@ -265,12 +256,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": True, "purge": False}) - channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body, + content={"block": True, "purge": False}, access_token=self.admin_user_tok, ) @@ -342,7 +331,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - json.dumps({"new_room_user_id": self.admin_user}), + {"new_room_user_id": self.admin_user}, access_token=self.admin_user_tok, ) @@ -372,7 +361,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), + {"history_visibility": "world_readable"}, access_token=self.other_user_tok, ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) @@ -388,7 +377,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - json.dumps({"new_room_user_id": self.admin_user}), + {"new_room_user_id": self.admin_user}, access_token=self.admin_user_tok, ) @@ -1782,12 +1771,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ - body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", self.url, - content=body, + content={"user_id": self.second_user_id}, access_token=self.second_tok, ) @@ -1798,12 +1786,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ If a parameter is missing, return an error """ - body = json.dumps({"unknown_parameter": "@unknown:test"}) channel = self.make_request( "POST", self.url, - content=body, + content={"unknown_parameter": "@unknown:test"}, access_token=self.admin_user_tok, ) @@ -1814,12 +1801,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ - body = json.dumps({"user_id": "@unknown:test"}) channel = self.make_request( "POST", self.url, - content=body, + content={"user_id": "@unknown:test"}, access_token=self.admin_user_tok, ) @@ -1830,12 +1816,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ Check that only local user can join rooms. """ - body = json.dumps({"user_id": "@not:exist.bla"}) channel = self.make_request( "POST", self.url, - content=body, + content={"user_id": "@not:exist.bla"}, access_token=self.admin_user_tok, ) @@ -1849,13 +1834,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. """ - body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/!unknown:test" channel = self.make_request( "POST", url, - content=body, + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) @@ -1866,13 +1850,12 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. """ - body = json.dumps({"user_id": self.second_user_id}) url = "/_synapse/admin/v1/join/invalidroom" channel = self.make_request( "POST", url, - content=body, + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) @@ -1886,12 +1869,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): """ Test joining a local user to a public room with "JoinRules.PUBLIC" """ - body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", self.url, - content=body, + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) @@ -1917,12 +1899,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.creator, tok=self.creator_tok, is_public=False ) url = f"/_synapse/admin/v1/join/{private_room_id}" - body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", url, - content=body, + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) @@ -1960,12 +1941,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Join user to room. url = f"/_synapse/admin/v1/join/{private_room_id}" - body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", url, - content=body, + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) @@ -1990,12 +1970,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.admin_user, tok=self.admin_user_tok, is_public=False ) url = f"/_synapse/admin/v1/join/{private_room_id}" - body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( "POST", url, - content=body, + content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) -- cgit 1.5.1 From 7ff22d6da41cd5ca80db95c18b409aea38e49fcd Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 30 Nov 2021 16:28:02 +0000 Subject: Fix `LruCache` corruption bug with a `size_callback` that can return 0 (#11454) When all entries in an `LruCache` have a size of 0 according to the provided `size_callback`, and `drop_from_cache` is called on a cache node, the node would be unlinked from the LRU linked list but remain in the cache dictionary. An assertion would be later be tripped due to the inconsistency. Avoid unintentionally calling `__len__` and use a strict `is None` check instead when unwrapping the weak reference. --- changelog.d/11454.bugfix | 1 + synapse/util/caches/lrucache.py | 5 ++++- tests/util/test_lrucache.py | 12 ++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 changelog.d/11454.bugfix (limited to 'tests') diff --git a/changelog.d/11454.bugfix b/changelog.d/11454.bugfix new file mode 100644 index 0000000000..096265cbc9 --- /dev/null +++ b/changelog.d/11454.bugfix @@ -0,0 +1 @@ +Fix an `LruCache` corruption bug, introduced in 1.38.0, that would cause certain requests to fail until the next Synapse restart. diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 05c4dcb062..eb96f7e665 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -271,7 +271,10 @@ class _Node(Generic[KT, VT]): removed from all lists. """ cache = self._cache() - if not cache or not cache.pop(self.key, None): + if ( + cache is None + or cache.pop(self.key, _Sentinel.sentinel) is _Sentinel.sentinel + ): # `cache.pop` should call `drop_from_lists()`, unless this Node had # already been removed from the cache. self.drop_from_lists() diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 6578f3411e..291644eb7d 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -13,6 +13,7 @@ # limitations under the License. +from typing import List from unittest.mock import Mock from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries @@ -261,6 +262,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): self.assertEquals(cache["key4"], [4]) self.assertEquals(cache["key5"], [5, 6]) + def test_zero_size_drop_from_cache(self) -> None: + """Test that `drop_from_cache` works correctly with 0-sized entries.""" + cache: LruCache[str, List[int]] = LruCache(5, size_callback=lambda x: 0) + cache["key1"] = [] + + self.assertEqual(len(cache), 0) + cache.cache["key1"].drop_from_cache() + self.assertIsNone( + cache.pop("key1"), "Cache entry should have been evicted but wasn't" + ) + class TimeEvictionTestCase(unittest.HomeserverTestCase): """Test that time based eviction works correctly.""" -- cgit 1.5.1 From 379f2650cf875f50c59524147ec0e33cfd5ef60c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 30 Nov 2021 11:33:33 -0500 Subject: Bundle relations of relations into the `/relations` result. (#11284) Per updates to MSC2675 which now states that bundled aggregations should be included from the `/relations` endpoint. --- changelog.d/11284.feature | 1 + synapse/events/utils.py | 8 +++ synapse/rest/client/relations.py | 9 +-- tests/rest/client/test_relations.py | 118 ++++++++++++++++++++++++++++++++++++ 4 files changed, 130 insertions(+), 6 deletions(-) create mode 100644 changelog.d/11284.feature (limited to 'tests') diff --git a/changelog.d/11284.feature b/changelog.d/11284.feature new file mode 100644 index 0000000000..cbaa5a988c --- /dev/null +++ b/changelog.d/11284.feature @@ -0,0 +1 @@ +When returning relation events from the `/relations` API, bundle any relations of those relations into the result, per updates to [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/synapse/events/utils.py b/synapse/events/utils.py index e5967c995e..05219a9dd0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -435,6 +435,14 @@ class EventClientSerializer: serialized_event: The serialized event which may be modified. """ + # Do not bundle relations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return + event_id = event.event_id # The bundled relations to include. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 45e9f1dd90..b1a3304849 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -230,12 +230,9 @@ class RelationPaginationServlet(RestServlet): original_event = await self._event_serializer.serialize_event( event, now, bundle_relations=False ) - # Similarly, we don't allow relations to be applied to relations, so we - # return the original relations without any aggregations on top of them - # here. - serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_relations=False - ) + # The relations returned for the requested event do include their + # bundled relations. + serialized_events = await self._event_serializer.serialize_events(events, now) return_value = pagination_chunk.to_dict() return_value["chunk"] = serialized_events diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index eb10d43217..b494da5138 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -526,6 +526,74 @@ class RelationsTestCase(unittest.HomeserverTestCase): }, ) + def test_aggregation_get_event_for_annotation(self): + """Test that annotations do not get bundled relations included + when directly requested. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + annotation_id = channel.json_body["event_id"] + + # Annotate the annotation. + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{annotation_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) + + def test_aggregation_get_event_for_thread(self): + """Test that threads get bundled relations included when directly requested.""" + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + thread_id = channel.json_body["event_id"] + + # Annotate the annotation. + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{thread_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + # It should also be included when the entire thread is requested. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1) + + thread_message = channel.json_body["chunk"][0] + self.assertEquals( + thread_message["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + def test_edit(self): """Test that a simple edit works.""" @@ -672,6 +740,56 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + def test_edit_edit(self): + """Test that an edit cannot be edited.""" + new_body = {"msgtype": "m.text", "body": "Initial edit"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": new_body, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + edit_event_id = channel.json_body["event_id"] + + # Edit the edit event. + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={ + "msgtype": "m.text", + "body": "foo", + "m.new_content": {"msgtype": "m.text", "body": "Ignored edit"}, + }, + parent_id=edit_event_id, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Request the original event. + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + # The edit to the edit should be ignored. + self.assertEquals(channel.json_body["content"], new_body) + + # The relations information should not include the edit to the edit. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + def test_relations_redaction_redacts_edits(self): """Test that edits of an event are redacted when the original event is redacted. -- cgit 1.5.1 From 70cbb1a5e311f609b624e3fae1a1712db639c51e Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 30 Nov 2021 10:12:18 -0800 Subject: Don't start Synapse master process if `worker_app` is set (#11416) * Add check to catch syanpse master process starting when workers are configured * add test to verify that starting master process with worker config raises error * newsfragment * specify config.worker.worker_app in check * update test * report specific config option that triggered the error Co-authored-by: reivilibre * clarify error message Co-authored-by: reivilibre Co-authored-by: reivilibre --- changelog.d/11416.misc | 1 + synapse/app/homeserver.py | 7 +++++++ tests/app/test_homeserver_start.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 changelog.d/11416.misc create mode 100644 tests/app/test_homeserver_start.py (limited to 'tests') diff --git a/changelog.d/11416.misc b/changelog.d/11416.misc new file mode 100644 index 0000000000..a5c3aeda83 --- /dev/null +++ b/changelog.d/11416.misc @@ -0,0 +1 @@ +Add a check to ensure that users cannot start the Synapse master process when `worker_app` is set. \ No newline at end of file diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 52541faab2..dd76e07321 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -358,6 +358,13 @@ def setup(config_options: List[str]) -> SynapseHomeServer: # generating config files and shouldn't try to continue. sys.exit(0) + if config.worker.worker_app: + raise ConfigError( + "You have specified `worker_app` in the config but are attempting to start a non-worker " + "instance. Please use `python -m synapse.app.generic_worker` instead (or remove the option if this is the main process)." + ) + sys.exit(1) + events.USE_FROZEN_DICTS = config.server.use_frozen_dicts synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage diff --git a/tests/app/test_homeserver_start.py b/tests/app/test_homeserver_start.py new file mode 100644 index 0000000000..cbcada0451 --- /dev/null +++ b/tests/app/test_homeserver_start.py @@ -0,0 +1,31 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse.app.homeserver +from synapse.config._base import ConfigError + +from tests.config.utils import ConfigFileTestCase + + +class HomeserverAppStartTestCase(ConfigFileTestCase): + def test_wrong_start_caught(self): + # Generate a config with a worker_app + self.generate_config() + # Add a blank line as otherwise the next addition ends up on a line with a comment + self.add_lines_to_config([" "]) + self.add_lines_to_config(["worker_app: test_worker_app"]) + + # Ensure that starting master process with worker config raises an exception + with self.assertRaises(ConfigError): + synapse.app.homeserver.setup(["-c", self.config_file]) -- cgit 1.5.1 From ed635d32853ee0a3e5ec1078679b27e7844a4ac7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Dec 2021 12:51:14 -0500 Subject: Add a license header and comment. (#11479) --- changelog.d/11479.feature | 1 + tests/storage/test_background_update.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+) create mode 100644 changelog.d/11479.feature (limited to 'tests') diff --git a/changelog.d/11479.feature b/changelog.d/11479.feature new file mode 100644 index 0000000000..aba3292015 --- /dev/null +++ b/changelog.d/11479.feature @@ -0,0 +1 @@ +Add plugin support for controlling database background updates. diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 216d816d56..d77c001506 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,3 +1,18 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Use backported mock for AsyncMock support on Python 3.6. from mock import Mock from twisted.internet.defer import Deferred, ensureDeferred -- cgit 1.5.1 From 435f04480728c5d982e1a63c1b2777784bf9cd26 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 2 Dec 2021 15:30:05 +0000 Subject: Add type annotations to `tests.storage.test_appservice`. (#11488) --- changelog.d/11488.misc | 1 + mypy.ini | 1 - synapse/appservice/__init__.py | 3 +- synapse/storage/databases/main/appservice.py | 6 +- tests/storage/test_appservice.py | 140 ++++++++++++++++++--------- 5 files changed, 98 insertions(+), 53 deletions(-) create mode 100644 changelog.d/11488.misc (limited to 'tests') diff --git a/changelog.d/11488.misc b/changelog.d/11488.misc new file mode 100644 index 0000000000..c14a7d2e98 --- /dev/null +++ b/changelog.d/11488.misc @@ -0,0 +1 @@ +Add type annotations to `tests.storage.test_appservice`. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 51056a8f64..99b5c41ad6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -111,7 +111,6 @@ exclude = (?x) |tests/server_notices/test_resource_limits_server_notices.py |tests/state/test_v2.py |tests/storage/test_account_data.py - |tests/storage/test_appservice.py |tests/storage/test_background_update.py |tests/storage/test_base.py |tests/storage/test_client_ips.py diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 6504c6bd3f..f9d3bd337d 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import re +from enum import Enum from typing import TYPE_CHECKING, Iterable, List, Match, Optional from synapse.api.constants import EventTypes @@ -27,7 +28,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ApplicationServiceState: +class ApplicationServiceState(Enum): DOWN = "down" UP = "up" diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index baec35ee27..4a883dc166 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore( A list of ApplicationServices, which may be empty. """ results = await self.db_pool.simple_select_list( - "application_services_state", {"state": state}, ["as_id"] + "application_services_state", {"state": state.value}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() @@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore( desc="get_appservice_state", ) if result: - return result.get("state") + return ApplicationServiceState(result.get("state")) return None async def set_appservice_state( @@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore( state: The connectivity state to apply. """ await self.db_pool.simple_upsert( - "application_services_state", {"as_id": service.id}, {"state": state} + "application_services_state", {"as_id": service.id}, {"state": state.value} ) async def create_appservice_txn( diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index f26d5acf9c..4b20a28ca2 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -14,19 +14,25 @@ import json import os import tempfile +from typing import Any, Generator, List, Optional, cast from unittest.mock import Mock import yaml from twisted.internet import defer +from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactor from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError +from synapse.events import EventBase +from synapse.server import HomeServer from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -36,7 +42,7 @@ from tests.utils import setup_test_homeserver class ApplicationServiceStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.as_yaml_files = [] + self.as_yaml_files: List[str] = [] hs = yield setup_test_homeserver( self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) @@ -58,7 +64,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): database, make_conn(database._database_config, database.engine, "test"), hs ) - def tearDown(self): + def tearDown(self) -> None: # TODO: suboptimal that we need to create files for tests! for f in self.as_yaml_files: try: @@ -66,7 +72,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): except Exception: pass - def _add_appservice(self, as_token, id, url, hs_token, sender): + def _add_appservice(self, as_token, id, url, hs_token, sender) -> None: as_yaml = { "url": url, "as_token": as_token, @@ -80,12 +86,13 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - def test_retrieve_unknown_service_token(self): + def test_retrieve_unknown_service_token(self) -> None: service = self.store.get_app_service_by_token("invalid_token") self.assertEquals(service, None) - def test_retrieval_of_service(self): + def test_retrieval_of_service(self) -> None: stored_service = self.store.get_app_service_by_token(self.as_token) + assert stored_service is not None self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.id, self.as_id) self.assertEquals(stored_service.url, self.as_url) @@ -93,7 +100,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], []) self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], []) - def test_retrieval_of_all_services(self): + def test_retrieval_of_all_services(self) -> None: services = self.store.get_app_services() self.assertEquals(len(services), 3) @@ -101,7 +108,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.as_yaml_files = [] + self.as_yaml_files: List[str] = [] hs = yield setup_test_homeserver( self.addCleanup, federation_sender=Mock(), federation_client=Mock() @@ -117,7 +124,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): {"token": "gamma_tok", "url": "https://gamma.com", "id": "id_gamma"}, ] for s in self.as_list: - yield self._add_service(s["url"], s["token"], s["id"]) + self._add_service(s["url"], s["token"], s["id"]) self.as_yaml_files = [] @@ -131,7 +138,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): database, make_conn(db_config, self.engine, "test"), hs ) - def _add_service(self, url, as_token, id): + def _add_service(self, url, as_token, id) -> None: as_yaml = { "url": url, "as_token": as_token, @@ -145,13 +152,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): outfile.write(yaml.dump(as_yaml)) self.as_yaml_files.append(as_token) - def _set_state(self, id, state, txn=None): + def _set_state( + self, id: str, state: ApplicationServiceState, txn: Optional[int] = None + ): return self.db_pool.runOperation( self.engine.convert_param_style( "INSERT INTO application_services_state(as_id, state, last_txn) " "VALUES(?,?,?)" ), - (id, state, txn), + (id, state.value, txn), ) def _insert_txn(self, as_id, txn_id, events): @@ -169,24 +178,30 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): "INSERT INTO application_services_state(as_id, last_txn, state) " "VALUES(?,?,?)" ), - (as_id, txn_id, ApplicationServiceState.UP), + (as_id, txn_id, ApplicationServiceState.UP.value), ) @defer.inlineCallbacks - def test_get_appservice_state_none(self): + def test_get_appservice_state_none( + self, + ) -> Generator["Deferred[object]", object, None]: service = Mock(id="999") state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(None, state) @defer.inlineCallbacks - def test_get_appservice_state_up(self): + def test_get_appservice_state_up( + self, + ) -> Generator["Deferred[object]", object, None]: yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) service = Mock(id=self.as_list[0]["id"]) state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.UP, state) @defer.inlineCallbacks - def test_get_appservice_state_down(self): + def test_get_appservice_state_down( + self, + ) -> Generator["Deferred[object]", object, None]: yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) @@ -195,14 +210,18 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(ApplicationServiceState.DOWN, state) @defer.inlineCallbacks - def test_get_appservices_by_state_none(self): + def test_get_appservices_by_state_none( + self, + ) -> Generator["Deferred[object]", Any, None]: services = yield defer.ensureDeferred( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(0, len(services)) @defer.inlineCallbacks - def test_set_appservices_state_down(self): + def test_set_appservices_state_down( + self, + ) -> Generator["Deferred[object]", Any, None]: service = Mock(id=self.as_list[1]["id"]) yield defer.ensureDeferred( self.store.set_appservice_state(service, ApplicationServiceState.DOWN) @@ -211,12 +230,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" ), - (ApplicationServiceState.DOWN,), + (ApplicationServiceState.DOWN.value,), ) self.assertEquals(service.id, rows[0][0]) @defer.inlineCallbacks - def test_set_appservices_state_multiple_up(self): + def test_set_appservices_state_multiple_up( + self, + ) -> Generator["Deferred[object]", Any, None]: service = Mock(id=self.as_list[1]["id"]) yield defer.ensureDeferred( self.store.set_appservice_state(service, ApplicationServiceState.UP) @@ -231,14 +252,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" ), - (ApplicationServiceState.UP,), + (ApplicationServiceState.UP.value,), ) self.assertEquals(service.id, rows[0][0]) @defer.inlineCallbacks - def test_create_appservice_txn_first(self): + def test_create_appservice_txn_first( + self, + ) -> Generator["Deferred[object]", Any, None]: service = Mock(id=self.as_list[0]["id"]) - events = [Mock(event_id="e1"), Mock(event_id="e2")] + events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = yield defer.ensureDeferred( self.store.create_appservice_txn(service, events, []) ) @@ -247,9 +270,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(txn.service, service) @defer.inlineCallbacks - def test_create_appservice_txn_older_last_txn(self): + def test_create_appservice_txn_older_last_txn( + self, + ) -> Generator["Deferred[object]", Any, None]: service = Mock(id=self.as_list[0]["id"]) - events = [Mock(event_id="e1"), Mock(event_id="e2")] + events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) yield self._set_last_txn(service.id, 9643) # AS is falling behind yield self._insert_txn(service.id, 9644, events) yield self._insert_txn(service.id, 9645, events) @@ -261,9 +286,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(txn.service, service) @defer.inlineCallbacks - def test_create_appservice_txn_up_to_date_last_txn(self): + def test_create_appservice_txn_up_to_date_last_txn( + self, + ) -> Generator["Deferred[object]", Any, None]: service = Mock(id=self.as_list[0]["id"]) - events = [Mock(event_id="e1"), Mock(event_id="e2")] + events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) yield self._set_last_txn(service.id, 9643) txn = yield defer.ensureDeferred( self.store.create_appservice_txn(service, events, []) @@ -273,9 +300,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(txn.service, service) @defer.inlineCallbacks - def test_create_appservice_txn_up_fuzzing(self): + def test_create_appservice_txn_up_fuzzing( + self, + ) -> Generator["Deferred[object]", Any, None]: service = Mock(id=self.as_list[0]["id"]) - events = [Mock(event_id="e1"), Mock(event_id="e2")] + events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) yield self._set_last_txn(service.id, 9643) # dump in rows with higher IDs to make sure the queries aren't wrong. @@ -296,7 +325,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(txn.service, service) @defer.inlineCallbacks - def test_complete_appservice_txn_first_txn(self): + def test_complete_appservice_txn_first_txn( + self, + ) -> Generator["Deferred[object]", Any, None]: service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 1 @@ -324,7 +355,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(0, len(res)) @defer.inlineCallbacks - def test_complete_appservice_txn_existing_in_state_table(self): + def test_complete_appservice_txn_existing_in_state_table( + self, + ) -> Generator["Deferred[object]", Any, None]: service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 5 @@ -342,7 +375,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) - self.assertEquals(ApplicationServiceState.UP, res[0][1]) + self.assertEquals(ApplicationServiceState.UP.value, res[0][1]) res = yield self.db_pool.runQuery( self.engine.convert_param_style( @@ -353,20 +386,23 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(0, len(res)) @defer.inlineCallbacks - def test_get_oldest_unsent_txn_none(self): + def test_get_oldest_unsent_txn_none( + self, + ) -> Generator["Deferred[object]", Any, None]: service = Mock(id=self.as_list[0]["id"]) txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(None, txn) @defer.inlineCallbacks - def test_get_oldest_unsent_txn(self): + def test_get_oldest_unsent_txn(self) -> Generator["Deferred[object]", Any, None]: service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) + # (ignore needed because Mypy won't allow us to assign to a method otherwise) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment] yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) @@ -379,7 +415,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(events, txn.events) @defer.inlineCallbacks - def test_get_appservices_by_state_single(self): + def test_get_appservices_by_state_single( + self, + ) -> Generator["Deferred[object]", Any, None]: yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) @@ -390,7 +428,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.assertEquals(self.as_list[0]["id"], services[0].id) @defer.inlineCallbacks - def test_get_appservices_by_state_multiple(self): + def test_get_appservices_by_state_multiple( + self, + ) -> Generator["Deferred[object]", Any, None]: yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) @@ -407,16 +447,20 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor, clock) -> "HomeServer": hs = self.setup_test_homeserver() return hs - def prepare(self, hs, reactor, clock): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.service = Mock(id="foo") self.store = self.hs.get_datastore() - self.get_success(self.store.set_appservice_state(self.service, "up")) + self.get_success( + self.store.set_appservice_state(self.service, ApplicationServiceState.UP) + ) - def test_get_type_stream_id_for_appservice_no_value(self): + def test_get_type_stream_id_for_appservice_no_value(self) -> None: value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") ) @@ -427,13 +471,13 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ) self.assertEquals(value, 0) - def test_get_type_stream_id_for_appservice_invalid_type(self): + def test_get_type_stream_id_for_appservice_invalid_type(self) -> None: self.get_failure( self.store.get_type_stream_id_for_appservice(self.service, "foobar"), ValueError, ) - def test_set_type_stream_id_for_appservice(self): + def test_set_type_stream_id_for_appservice(self) -> None: read_receipt_value = 1024 self.get_success( self.store.set_type_stream_id_for_appservice( @@ -455,7 +499,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ) self.assertEqual(result, read_receipt_value) - def test_set_type_stream_id_for_appservice_invalid_type(self): + def test_set_type_stream_id_for_appservice_invalid_type(self) -> None: self.get_failure( self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024), ValueError, @@ -464,12 +508,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs) -> None: super().__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): - def _write_config(self, suffix, **kwargs): + def _write_config(self, suffix, **kwargs) -> str: vals = { "id": "id" + suffix, "url": "url" + suffix, @@ -486,7 +530,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): return path @defer.inlineCallbacks - def test_unique_works(self): + def test_unique_works(self) -> Generator["Deferred[object]", Any, None]: f1 = self._write_config(suffix="1") f2 = self._write_config(suffix="2") @@ -503,7 +547,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) @defer.inlineCallbacks - def test_duplicate_ids(self): + def test_duplicate_ids(self) -> Generator["Deferred[object]", Any, None]: f1 = self._write_config(id="id", suffix="1") f2 = self._write_config(id="id", suffix="2") @@ -528,7 +572,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): self.assertIn("id", str(e)) @defer.inlineCallbacks - def test_duplicate_as_tokens(self): + def test_duplicate_as_tokens(self) -> Generator["Deferred[object]", Any, None]: f1 = self._write_config(as_token="as_token", suffix="1") f2 = self._write_config(as_token="as_token", suffix="2") -- cgit 1.5.1 From 858d80bf0f9f656a03992794874081b806e49222 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Thu, 2 Dec 2021 16:05:24 +0000 Subject: Fix media repository failing when media store path contains symlinks (#11446) --- changelog.d/11446.bugfix | 1 + synapse/rest/media/v1/filepath.py | 115 +++++++++++++++++++++-------------- tests/rest/media/v1/test_filepath.py | 109 ++++++++++++++++++++++++++++++++- 3 files changed, 180 insertions(+), 45 deletions(-) create mode 100644 changelog.d/11446.bugfix (limited to 'tests') diff --git a/changelog.d/11446.bugfix b/changelog.d/11446.bugfix new file mode 100644 index 0000000000..fa5e055d50 --- /dev/null +++ b/changelog.d/11446.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in 1.47.1 where the media repository would fail to work if the media store path contained any symbolic links. diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index c0e15c6513..1f6441c412 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -43,47 +43,75 @@ GetPathMethod = TypeVar( ) -def _wrap_with_jail_check(func: GetPathMethod) -> GetPathMethod: +def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]: """Wraps a path-returning method to check that the returned path(s) do not escape the media store directory. + The path-returning method may return either a single path, or a list of paths. + The check is not expected to ever fail, unless `func` is missing a call to `_validate_path_component`, or `_validate_path_component` is buggy. Args: - func: The `MediaFilePaths` method to wrap. The method may return either a single - path, or a list of paths. Returned paths may be either absolute or relative. + relative: A boolean indicating whether the wrapped method returns paths relative + to the media store directory. Returns: - The method, wrapped with a check to ensure that the returned path(s) lie within - the media store directory. Raises a `ValueError` if the check fails. + A method which will wrap a path-returning method, adding a check to ensure that + the returned path(s) lie within the media store directory. The check will raise + a `ValueError` if it fails. """ - @functools.wraps(func) - def _wrapped( - self: "MediaFilePaths", *args: Any, **kwargs: Any - ) -> Union[str, List[str]]: - path_or_paths = func(self, *args, **kwargs) - - if isinstance(path_or_paths, list): - paths_to_check = path_or_paths - else: - paths_to_check = [path_or_paths] - - for path in paths_to_check: - # path may be an absolute or relative path, depending on the method being - # wrapped. When "appending" an absolute path, `os.path.join` discards the - # previous path, which is desired here. - normalized_path = os.path.normpath(os.path.join(self.real_base_path, path)) - if ( - os.path.commonpath([normalized_path, self.real_base_path]) - != self.real_base_path - ): - raise ValueError(f"Invalid media store path: {path!r}") - - return path_or_paths - - return cast(GetPathMethod, _wrapped) + def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod: + @functools.wraps(func) + def _wrapped( + self: "MediaFilePaths", *args: Any, **kwargs: Any + ) -> Union[str, List[str]]: + path_or_paths = func(self, *args, **kwargs) + + if isinstance(path_or_paths, list): + paths_to_check = path_or_paths + else: + paths_to_check = [path_or_paths] + + for path in paths_to_check: + # Construct the path that will ultimately be used. + # We cannot guess whether `path` is relative to the media store + # directory, since the media store directory may itself be a relative + # path. + if relative: + path = os.path.join(self.base_path, path) + normalized_path = os.path.normpath(path) + + # Now that `normpath` has eliminated `../`s and `./`s from the path, + # `os.path.commonpath` can be used to check whether it lies within the + # media store directory. + if ( + os.path.commonpath([normalized_path, self.normalized_base_path]) + != self.normalized_base_path + ): + # The path resolves to outside the media store directory, + # or `self.base_path` is `.`, which is an unlikely configuration. + raise ValueError(f"Invalid media store path: {path!r}") + + # Note that `os.path.normpath`/`abspath` has a subtle caveat: + # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a + # different path if `a/b/c` is a symlink. That is, the check above is + # not perfect and may allow a certain restricted subset of untrustworthy + # paths through. Since the check above is secondary to the main + # `_validate_path_component` checks, it's less important for it to be + # perfect. + # + # As an alternative, `os.path.realpath` will resolve symlinks, but + # proves problematic if there are symlinks inside the media store. + # eg. if `url_store/` is symlinked to elsewhere, its canonical path + # won't match that of the main media store directory. + + return path_or_paths + + return cast(GetPathMethod, _wrapped) + + return _wrap_with_jail_check_inner ALLOWED_CHARACTERS = set( @@ -127,9 +155,7 @@ class MediaFilePaths: def __init__(self, primary_base_path: str): self.base_path = primary_base_path - - # The media store directory, with all symlinks resolved. - self.real_base_path = os.path.realpath(primary_base_path) + self.normalized_base_path = os.path.normpath(self.base_path) # Refuse to initialize if paths cannot be validated correctly for the current # platform. @@ -140,7 +166,7 @@ class MediaFilePaths: # for certain homeservers there, since ":"s aren't allowed in paths. assert os.name == "posix" - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def local_media_filepath_rel(self, media_id: str) -> str: return os.path.join( "local_content", @@ -151,7 +177,7 @@ class MediaFilePaths: local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def local_media_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: @@ -167,7 +193,7 @@ class MediaFilePaths: local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=False) def local_media_thumbnail_dir(self, media_id: str) -> str: """ Retrieve the local store path of thumbnails of a given media_id @@ -185,7 +211,7 @@ class MediaFilePaths: _validate_path_component(media_id[4:]), ) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str: return os.path.join( "remote_content", @@ -197,7 +223,7 @@ class MediaFilePaths: remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def remote_media_thumbnail_rel( self, server_name: str, @@ -223,7 +249,7 @@ class MediaFilePaths: # Legacy path that was used to store thumbnails previously. # Should be removed after some time, when most of the thumbnails are stored # using the new path. - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def remote_media_thumbnail_rel_legacy( self, server_name: str, file_id: str, width: int, height: int, content_type: str ) -> str: @@ -238,6 +264,7 @@ class MediaFilePaths: _validate_path_component(file_name), ) + @_wrap_with_jail_check(relative=False) def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str: return os.path.join( self.base_path, @@ -248,7 +275,7 @@ class MediaFilePaths: _validate_path_component(file_id[4:]), ) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def url_cache_filepath_rel(self, media_id: str) -> str: if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form @@ -268,7 +295,7 @@ class MediaFilePaths: url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=False) def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): @@ -290,7 +317,7 @@ class MediaFilePaths: ), ] - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def url_cache_thumbnail_rel( self, media_id: str, width: int, height: int, content_type: str, method: str ) -> str: @@ -318,7 +345,7 @@ class MediaFilePaths: url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=True) def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -341,7 +368,7 @@ class MediaFilePaths: url_cache_thumbnail_directory_rel ) - @_wrap_with_jail_check + @_wrap_with_jail_check(relative=False) def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py index 8fe94f7d85..913bc530aa 100644 --- a/tests/rest/media/v1/test_filepath.py +++ b/tests/rest/media/v1/test_filepath.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import os from typing import Iterable -from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check from tests import unittest @@ -486,3 +487,109 @@ class MediaFilePathsTestCase(unittest.TestCase): f"{value!r} unexpectedly passed validation: " f"{method} returned {path_or_list!r}" ) + + +class MediaFilePathsJailTestCase(unittest.TestCase): + def _check_relative_path(self, filepaths: MediaFilePaths, path: str) -> None: + """Passes a relative path through the jail check. + + Args: + filepaths: The `MediaFilePaths` instance. + path: A path relative to the media store directory. + + Raises: + ValueError: If the jail check fails. + """ + + @_wrap_with_jail_check(relative=True) + def _make_relative_path(self: MediaFilePaths, path: str) -> str: + return path + + _make_relative_path(filepaths, path) + + def _check_absolute_path(self, filepaths: MediaFilePaths, path: str) -> None: + """Passes an absolute path through the jail check. + + Args: + filepaths: The `MediaFilePaths` instance. + path: A path relative to the media store directory. + + Raises: + ValueError: If the jail check fails. + """ + + @_wrap_with_jail_check(relative=False) + def _make_absolute_path(self: MediaFilePaths, path: str) -> str: + return os.path.join(self.base_path, path) + + _make_absolute_path(filepaths, path) + + def test_traversal_inside(self) -> None: + """Test the jail check for paths that stay within the media directory.""" + # Despite the `../`s, these paths still lie within the media directory and it's + # expected for the jail check to allow them through. + # These paths ought to trip the other checks in place and should never be + # returned. + filepaths = MediaFilePaths("/media_store") + path = "url_cache/2020-01-02/../../GerZNDnDZVjsOtar" + self._check_relative_path(filepaths, path) + self._check_absolute_path(filepaths, path) + + def test_traversal_outside(self) -> None: + """Test that the jail check fails for paths that escape the media directory.""" + filepaths = MediaFilePaths("/media_store") + path = "url_cache/2020-01-02/../../../GerZNDnDZVjsOtar" + with self.assertRaises(ValueError): + self._check_relative_path(filepaths, path) + with self.assertRaises(ValueError): + self._check_absolute_path(filepaths, path) + + def test_traversal_reentry(self) -> None: + """Test the jail check for paths that exit and re-enter the media directory.""" + # These paths lie outside the media directory if it is a symlink, and inside + # otherwise. Ideally the check should fail, but this proves difficult. + # This test documents the behaviour for this edge case. + # These paths ought to trip the other checks in place and should never be + # returned. + filepaths = MediaFilePaths("/media_store") + path = "url_cache/2020-01-02/../../../media_store/GerZNDnDZVjsOtar" + self._check_relative_path(filepaths, path) + self._check_absolute_path(filepaths, path) + + def test_symlink(self) -> None: + """Test that a symlink does not cause the jail check to fail.""" + media_store_path = self.mktemp() + + # symlink the media store directory + os.symlink("/mnt/synapse/media_store", media_store_path) + + # Test that relative and absolute paths don't trip the check + # NB: `media_store_path` is a relative path + filepaths = MediaFilePaths(media_store_path) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + + filepaths = MediaFilePaths(os.path.abspath(media_store_path)) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + + def test_symlink_subdirectory(self) -> None: + """Test that a symlinked subdirectory does not cause the jail check to fail.""" + media_store_path = self.mktemp() + os.mkdir(media_store_path) + + # symlink `url_cache/` + os.symlink( + "/mnt/synapse/media_store_url_cache", + os.path.join(media_store_path, "url_cache"), + ) + + # Test that relative and absolute paths don't trip the check + # NB: `media_store_path` is a relative path + filepaths = MediaFilePaths(media_store_path) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + + filepaths = MediaFilePaths(os.path.abspath(media_store_path)) + self._check_relative_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") + self._check_absolute_path(filepaths, "url_cache/2020-01-02/GerZNDnDZVjsOtar") -- cgit 1.5.1 From 8a4c2969874c0b7d72003f2523883eba8a348e83 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 2 Dec 2021 18:13:30 +0000 Subject: Clean up `tests.test_visibility` to remove legacy code. (#11495) --- changelog.d/11495.misc | 1 + mypy.ini | 1 - tests/test_visibility.py | 241 ++++++++--------------------------------------- 3 files changed, 40 insertions(+), 203 deletions(-) create mode 100644 changelog.d/11495.misc (limited to 'tests') diff --git a/changelog.d/11495.misc b/changelog.d/11495.misc new file mode 100644 index 0000000000..5b52697fb4 --- /dev/null +++ b/changelog.d/11495.misc @@ -0,0 +1 @@ +Clean up `tests.test_visibility` to remove legacy code. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index d8296b4fa3..fea71d154c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -123,7 +123,6 @@ exclude = (?x) |tests/test_server.py |tests/test_state.py |tests/test_terms_auth.py - |tests/test_visibility.py |tests/unittest.py |tests/util/caches/test_cached_call.py |tests/util/caches/test_deferred_cache.py diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 94b19788d7..e0b08d67d4 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -13,35 +13,30 @@ # limitations under the License. import logging from typing import Optional -from unittest.mock import Mock - -from twisted.internet import defer -from twisted.internet.defer import succeed from synapse.api.room_versions import RoomVersions -from synapse.events import FrozenEvent +from synapse.events import EventBase +from synapse.types import JsonDict from synapse.visibility import filter_events_for_server -import tests.unittest -from tests.utils import create_room, setup_test_homeserver +from tests import unittest +from tests.utils import create_room logger = logging.getLogger(__name__) TEST_ROOM_ID = "!TEST:ROOM" -class FilterEventsForServerTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) +class FilterEventsForServerTestCase(unittest.HomeserverTestCase): + def setUp(self) -> None: + super(FilterEventsForServerTestCase, self).setUp() self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self.storage = self.hs.get_storage() - yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) + self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) - @defer.inlineCallbacks - def test_filtering(self): + def test_filtering(self) -> None: # # The events to be filtered consist of 10 membership events (it doesn't # really matter if they are joins or leaves, so let's make them joins). @@ -51,18 +46,20 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # # before we do that, we persist some other events to act as state. - yield self.inject_visibility("@admin:hs", "joined") + self.get_success(self._inject_visibility("@admin:hs", "joined")) for i in range(0, 10): - yield self.inject_room_member("@resident%i:hs" % i) + self.get_success(self._inject_room_member("@resident%i:hs" % i)) events_to_filter = [] for i in range(0, 10): user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") - evt = yield self.inject_room_member(user, extra_content={"a": "b"}) + evt = self.get_success( + self._inject_room_member(user, extra_content={"a": "b"}) + ) events_to_filter.append(evt) - filtered = yield defer.ensureDeferred( + filtered = self.get_success( filter_events_for_server(self.storage, "test_server", events_to_filter) ) @@ -75,34 +72,31 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) self.assertEqual(filtered[i].content["a"], "b") - @defer.inlineCallbacks - def test_erased_user(self): + def test_erased_user(self) -> None: # 4 message events, from erased and unerased users, with a membership # change in the middle of them. events_to_filter = [] - evt = yield self.inject_message("@unerased:local_hs") + evt = self.get_success(self._inject_message("@unerased:local_hs")) events_to_filter.append(evt) - evt = yield self.inject_message("@erased:local_hs") + evt = self.get_success(self._inject_message("@erased:local_hs")) events_to_filter.append(evt) - evt = yield self.inject_room_member("@joiner:remote_hs") + evt = self.get_success(self._inject_room_member("@joiner:remote_hs")) events_to_filter.append(evt) - evt = yield self.inject_message("@unerased:local_hs") + evt = self.get_success(self._inject_message("@unerased:local_hs")) events_to_filter.append(evt) - evt = yield self.inject_message("@erased:local_hs") + evt = self.get_success(self._inject_message("@erased:local_hs")) events_to_filter.append(evt) # the erasey user gets erased - yield defer.ensureDeferred( - self.hs.get_datastore().mark_user_erased("@erased:local_hs") - ) + self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs")) # ... and the filtering happens. - filtered = yield defer.ensureDeferred( + filtered = self.get_success( filter_events_for_server(self.storage, "test_server", events_to_filter) ) @@ -123,8 +117,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): for i in (1, 4): self.assertNotIn("body", filtered[i].content) - @defer.inlineCallbacks - def inject_visibility(self, user_id, visibility): + def _inject_visibility(self, user_id: str, visibility: str) -> EventBase: content = {"history_visibility": visibility} builder = self.event_builder_factory.for_room_version( RoomVersions.V1, @@ -137,18 +130,18 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield defer.ensureDeferred( + event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - yield defer.ensureDeferred( - self.storage.persistence.persist_event(event, context) - ) + self.get_success(self.storage.persistence.persist_event(event, context)) return event - @defer.inlineCallbacks - def inject_room_member( - self, user_id, membership="join", extra_content: Optional[dict] = None - ): + def _inject_room_member( + self, + user_id: str, + membership: str = "join", + extra_content: Optional[JsonDict] = None, + ) -> EventBase: content = {"membership": membership} content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( @@ -162,17 +155,16 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield defer.ensureDeferred( + event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - yield defer.ensureDeferred( - self.storage.persistence.persist_event(event, context) - ) + self.get_success(self.storage.persistence.persist_event(event, context)) return event - @defer.inlineCallbacks - def inject_message(self, user_id, content=None): + def _inject_message( + self, user_id: str, content: Optional[JsonDict] = None + ) -> EventBase: if content is None: content = {"body": "testytest", "msgtype": "m.text"} builder = self.event_builder_factory.for_room_version( @@ -185,164 +177,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield defer.ensureDeferred( + event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - yield defer.ensureDeferred( - self.storage.persistence.persist_event(event, context) - ) + self.get_success(self.storage.persistence.persist_event(event, context)) return event - - @defer.inlineCallbacks - def test_large_room(self): - # see what happens when we have a large room with hundreds of thousands - # of membership events - - # As above, the events to be filtered consist of 10 membership events, - # where one of them is for a user on the server we are filtering for. - - import cProfile - import pstats - import time - - # we stub out the store, because building up all that state the normal - # way is very slow. - test_store = _TestStore() - - # our initial state is 100000 membership events and one - # history_visibility event. - room_state = [] - - history_visibility_evt = FrozenEvent( - { - "event_id": "$history_vis", - "type": "m.room.history_visibility", - "sender": "@resident_user_0:test.com", - "state_key": "", - "room_id": TEST_ROOM_ID, - "content": {"history_visibility": "joined"}, - } - ) - room_state.append(history_visibility_evt) - test_store.add_event(history_visibility_evt) - - for i in range(0, 100000): - user = "@resident_user_%i:test.com" % (i,) - evt = FrozenEvent( - { - "event_id": "$res_event_%i" % (i,), - "type": "m.room.member", - "state_key": user, - "sender": user, - "room_id": TEST_ROOM_ID, - "content": {"membership": "join", "extra": "zzz,"}, - } - ) - room_state.append(evt) - test_store.add_event(evt) - - events_to_filter = [] - for i in range(0, 10): - user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") - evt = FrozenEvent( - { - "event_id": "$evt%i" % (i,), - "type": "m.room.member", - "state_key": user, - "sender": user, - "room_id": TEST_ROOM_ID, - "content": {"membership": "join", "extra": "zzz"}, - } - ) - events_to_filter.append(evt) - room_state.append(evt) - - test_store.add_event(evt) - test_store.set_state_ids_for_event( - evt, {(e.type, e.state_key): e.event_id for e in room_state} - ) - - pr = cProfile.Profile() - pr.enable() - - logger.info("Starting filtering") - start = time.time() - - storage = Mock() - storage.main = test_store - storage.state = test_store - - filtered = yield defer.ensureDeferred( - filter_events_for_server(test_store, "test_server", events_to_filter) - ) - logger.info("Filtering took %f seconds", time.time() - start) - - pr.disable() - with open("filter_events_for_server.profile", "w+") as f: - ps = pstats.Stats(pr, stream=f).sort_stats("cumulative") - ps.print_stats() - - # the result should be 5 redacted events, and 5 unredacted events. - for i in range(0, 5): - self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) - self.assertNotIn("extra", filtered[i].content) - - for i in range(5, 10): - self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) - self.assertEqual(filtered[i].content["extra"], "zzz") - - test_large_room.skip = "Disabled by default because it's slow" - - -class _TestStore: - """Implements a few methods of the DataStore, so that we can test - filter_events_for_server - - """ - - def __init__(self): - # data for get_events: a map from event_id to event - self.events = {} - - # data for get_state_ids_for_events mock: a map from event_id to - # a map from (type_state_key) -> event_id for the state at that - # event - self.state_ids_for_events = {} - - def add_event(self, event): - self.events[event.event_id] = event - - def set_state_ids_for_event(self, event, state): - self.state_ids_for_events[event.event_id] = state - - def get_state_ids_for_events(self, events, types): - res = {} - include_memberships = False - for (type, state_key) in types: - if type == "m.room.history_visibility": - continue - if type != "m.room.member" or state_key is not None: - raise RuntimeError( - "Unimplemented: get_state_ids with type (%s, %s)" - % (type, state_key) - ) - include_memberships = True - - if include_memberships: - for event_id in events: - res[event_id] = self.state_ids_for_events[event_id] - - else: - k = ("m.room.history_visibility", "") - for event_id in events: - hve = self.state_ids_for_events[event_id][k] - res[event_id] = {k: hve} - - return succeed(res) - - def get_events(self, events): - return succeed({event_id: self.events[event_id] for event_id in events}) - - def are_users_erased(self, users): - return succeed({u: False for u in users}) -- cgit 1.5.1 From 16d39a5490ce74c901c7a8dbb990c6e83c379207 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 2 Dec 2021 18:13:43 +0000 Subject: Clean up `tests.storage.test_main` to remove use of legacy code. (#11493) --- changelog.d/11493.misc | 1 + tests/storage/test_main.py | 27 ++++++++++----------------- 2 files changed, 11 insertions(+), 17 deletions(-) create mode 100644 changelog.d/11493.misc (limited to 'tests') diff --git a/changelog.d/11493.misc b/changelog.d/11493.misc new file mode 100644 index 0000000000..646584a0d1 --- /dev/null +++ b/changelog.d/11493.misc @@ -0,0 +1 @@ +Clean up `tests.storage.test_main` to remove use of legacy code. \ No newline at end of file diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index d2b7b89952..f8d11bac4e 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -13,42 +13,35 @@ # limitations under the License. -from twisted.internet import defer - from synapse.types import UserID from tests import unittest -from tests.utils import setup_test_homeserver -class DataStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) +class DataStoreTestCase(unittest.HomeserverTestCase): + def setUp(self) -> None: + super(DataStoreTestCase, self).setUp() - self.store = hs.get_datastore() + self.store = self.hs.get_datastore() self.user = UserID.from_string("@abcde:test") self.displayname = "Frank" - @defer.inlineCallbacks - def test_get_users_paginate(self): - yield defer.ensureDeferred( - self.store.register_user(self.user.to_string(), "pass") - ) - yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) - yield defer.ensureDeferred( + def test_get_users_paginate(self) -> None: + self.get_success(self.store.register_user(self.user.to_string(), "pass")) + self.get_success(self.store.create_profile(self.user.localpart)) + self.get_success( self.store.set_profile_displayname(self.user.localpart, self.displayname) ) - users, total = yield defer.ensureDeferred( + users, total = self.get_success( self.store.get_users_paginate(0, 10, name="bc", guests=False) ) self.assertEquals(1, total) self.assertEquals(self.displayname, users.pop()["displayname"]) - users, total = yield defer.ensureDeferred( + users, total = self.get_success( self.store.get_users_paginate(0, 10, name="BC", guests=False) ) -- cgit 1.5.1 From f91624a5950e14ba9007eed9bfa1c828676d4745 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Thu, 2 Dec 2021 18:43:33 +0000 Subject: Clean up tests.storage.test_appservice (#11492) --- changelog.d/11492.misc | 1 + tests/storage/test_appservice.py | 345 +++++++++++++++++++-------------------- 2 files changed, 169 insertions(+), 177 deletions(-) create mode 100644 changelog.d/11492.misc (limited to 'tests') diff --git a/changelog.d/11492.misc b/changelog.d/11492.misc new file mode 100644 index 0000000000..c14a7d2e98 --- /dev/null +++ b/changelog.d/11492.misc @@ -0,0 +1 @@ +Add type annotations to `tests.storage.test_appservice`. \ No newline at end of file diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 4b20a28ca2..329490caad 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -14,13 +14,12 @@ import json import os import tempfile -from typing import Any, Generator, List, Optional, cast +from typing import List, Optional, cast from unittest.mock import Mock import yaml from twisted.internet import defer -from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor from synapse.appservice import ApplicationService, ApplicationServiceState @@ -36,19 +35,16 @@ from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable -from tests.utils import setup_test_homeserver -class ApplicationServiceStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks +class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): def setUp(self): + super(ApplicationServiceStoreTestCase, self).setUp() + self.as_yaml_files: List[str] = [] - hs = yield setup_test_homeserver( - self.addCleanup, federation_sender=Mock(), federation_client=Mock() - ) - hs.config.appservice.app_service_config_files = self.as_yaml_files - hs.config.caches.event_cache_size = 1 + self.hs.config.appservice.app_service_config_files = self.as_yaml_files + self.hs.config.caches.event_cache_size = 1 self.as_token = "token1" self.as_url = "some_url" @@ -59,9 +55,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - database = hs.get_datastores().databases[0] + database = self.hs.get_datastores().databases[0] self.store = ApplicationServiceStore( - database, make_conn(database._database_config, database.engine, "test"), hs + database, + make_conn(database._database_config, database.engine, "test"), + self.hs, ) def tearDown(self) -> None: @@ -72,6 +70,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): except Exception: pass + super(ApplicationServiceStoreTestCase, self).tearDown() + def _add_appservice(self, as_token, id, url, hs_token, sender) -> None: as_yaml = { "url": url, @@ -105,17 +105,13 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self.assertEquals(len(services), 3) -class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): +class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): + def setUp(self) -> None: + super(ApplicationServiceTransactionStoreTestCase, self).setUp() self.as_yaml_files: List[str] = [] - hs = yield setup_test_homeserver( - self.addCleanup, federation_sender=Mock(), federation_client=Mock() - ) - - hs.config.appservice.app_service_config_files = self.as_yaml_files - hs.config.caches.event_cache_size = 1 + self.hs.config.appservice.app_service_config_files = self.as_yaml_files + self.hs.config.caches.event_cache_size = 1 self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, @@ -129,13 +125,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] # We assume there is only one database in these tests - database = hs.get_datastores().databases[0] + database = self.hs.get_datastores().databases[0] self.db_pool = database._db_pool self.engine = database.engine - db_config = hs.config.database.get_single_database() + db_config = self.hs.config.database.get_single_database() self.store = TestTransactionStore( - database, make_conn(db_config, self.engine, "test"), hs + database, make_conn(db_config, self.engine, "test"), self.hs ) def _add_service(self, url, as_token, id) -> None: @@ -181,221 +177,223 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): (as_id, txn_id, ApplicationServiceState.UP.value), ) - @defer.inlineCallbacks def test_get_appservice_state_none( self, - ) -> Generator["Deferred[object]", object, None]: + ) -> None: service = Mock(id="999") - state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) + state = self.get_success(self.store.get_appservice_state(service)) self.assertEquals(None, state) - @defer.inlineCallbacks def test_get_appservice_state_up( self, - ) -> Generator["Deferred[object]", object, None]: - yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) + ) -> None: + self.get_success( + self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) + ) service = Mock(id=self.as_list[0]["id"]) - state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) + state = self.get_success( + defer.ensureDeferred(self.store.get_appservice_state(service)) + ) self.assertEquals(ApplicationServiceState.UP, state) - @defer.inlineCallbacks def test_get_appservice_state_down( self, - ) -> Generator["Deferred[object]", object, None]: - yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) - yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) - yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) + ) -> None: + self.get_success( + self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) + ) + self.get_success( + self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) + ) + self.get_success( + self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) + ) service = Mock(id=self.as_list[1]["id"]) - state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) + state = self.get_success(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.DOWN, state) - @defer.inlineCallbacks def test_get_appservices_by_state_none( self, - ) -> Generator["Deferred[object]", Any, None]: - services = yield defer.ensureDeferred( + ) -> None: + services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(0, len(services)) - @defer.inlineCallbacks def test_set_appservices_state_down( self, - ) -> Generator["Deferred[object]", Any, None]: + ) -> None: service = Mock(id=self.as_list[1]["id"]) - yield defer.ensureDeferred( + self.get_success( self.store.set_appservice_state(service, ApplicationServiceState.DOWN) ) - rows = yield self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT as_id FROM application_services_state WHERE state=?" - ), - (ApplicationServiceState.DOWN.value,), + rows = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), + (ApplicationServiceState.DOWN.value,), + ) ) self.assertEquals(service.id, rows[0][0]) - @defer.inlineCallbacks def test_set_appservices_state_multiple_up( self, - ) -> Generator["Deferred[object]", Any, None]: + ) -> None: service = Mock(id=self.as_list[1]["id"]) - yield defer.ensureDeferred( + self.get_success( self.store.set_appservice_state(service, ApplicationServiceState.UP) ) - yield defer.ensureDeferred( + self.get_success( self.store.set_appservice_state(service, ApplicationServiceState.DOWN) ) - yield defer.ensureDeferred( + self.get_success( self.store.set_appservice_state(service, ApplicationServiceState.UP) ) - rows = yield self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT as_id FROM application_services_state WHERE state=?" - ), - (ApplicationServiceState.UP.value,), + rows = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT as_id FROM application_services_state WHERE state=?" + ), + (ApplicationServiceState.UP.value,), + ) ) self.assertEquals(service.id, rows[0][0]) - @defer.inlineCallbacks def test_create_appservice_txn_first( self, - ) -> Generator["Deferred[object]", Any, None]: + ) -> None: service = Mock(id=self.as_list[0]["id"]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events, []) + txn = self.get_success( + defer.ensureDeferred(self.store.create_appservice_txn(service, events, [])) ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - @defer.inlineCallbacks def test_create_appservice_txn_older_last_txn( self, - ) -> Generator["Deferred[object]", Any, None]: + ) -> None: service = Mock(id=self.as_list[0]["id"]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - yield self._set_last_txn(service.id, 9643) # AS is falling behind - yield self._insert_txn(service.id, 9644, events) - yield self._insert_txn(service.id, 9645, events) - txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events, []) - ) + self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind + self.get_success(self._insert_txn(service.id, 9644, events)) + self.get_success(self._insert_txn(service.id, 9645, events)) + txn = self.get_success(self.store.create_appservice_txn(service, events, [])) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - @defer.inlineCallbacks def test_create_appservice_txn_up_to_date_last_txn( self, - ) -> Generator["Deferred[object]", Any, None]: + ) -> None: service = Mock(id=self.as_list[0]["id"]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - yield self._set_last_txn(service.id, 9643) - txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events, []) - ) + self.get_success(self._set_last_txn(service.id, 9643)) + txn = self.get_success(self.store.create_appservice_txn(service, events, [])) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - @defer.inlineCallbacks def test_create_appservice_txn_up_fuzzing( self, - ) -> Generator["Deferred[object]", Any, None]: + ) -> None: service = Mock(id=self.as_list[0]["id"]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - yield self._set_last_txn(service.id, 9643) + self.get_success(self._set_last_txn(service.id, 9643)) # dump in rows with higher IDs to make sure the queries aren't wrong. - yield self._set_last_txn(self.as_list[1]["id"], 119643) - yield self._set_last_txn(self.as_list[2]["id"], 9) - yield self._set_last_txn(self.as_list[3]["id"], 9643) - yield self._insert_txn(self.as_list[1]["id"], 119644, events) - yield self._insert_txn(self.as_list[1]["id"], 119645, events) - yield self._insert_txn(self.as_list[1]["id"], 119646, events) - yield self._insert_txn(self.as_list[2]["id"], 10, events) - yield self._insert_txn(self.as_list[3]["id"], 9643, events) - - txn = yield defer.ensureDeferred( - self.store.create_appservice_txn(service, events, []) - ) + self.get_success(self._set_last_txn(self.as_list[1]["id"], 119643)) + self.get_success(self._set_last_txn(self.as_list[2]["id"], 9)) + self.get_success(self._set_last_txn(self.as_list[3]["id"], 9643)) + self.get_success(self._insert_txn(self.as_list[1]["id"], 119644, events)) + self.get_success(self._insert_txn(self.as_list[1]["id"], 119645, events)) + self.get_success(self._insert_txn(self.as_list[1]["id"], 119646, events)) + self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events)) + self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) + + txn = self.get_success(self.store.create_appservice_txn(service, events, [])) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) - @defer.inlineCallbacks def test_complete_appservice_txn_first_txn( self, - ) -> Generator["Deferred[object]", Any, None]: + ) -> None: service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 1 - yield self._insert_txn(service.id, txn_id, events) - yield defer.ensureDeferred( + self.get_success(self._insert_txn(service.id, txn_id, events)) + self.get_success( self.store.complete_appservice_txn(txn_id=txn_id, service=service) ) - res = yield self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT last_txn FROM application_services_state WHERE as_id=?" - ), - (service.id,), + res = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT last_txn FROM application_services_state WHERE as_id=?" + ), + (service.id,), + ) ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) - res = yield self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT * FROM application_services_txns WHERE txn_id=?" - ), - (txn_id,), + res = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), + ) ) self.assertEquals(0, len(res)) - @defer.inlineCallbacks def test_complete_appservice_txn_existing_in_state_table( self, - ) -> Generator["Deferred[object]", Any, None]: + ) -> None: service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 5 - yield self._set_last_txn(service.id, 4) - yield self._insert_txn(service.id, txn_id, events) - yield defer.ensureDeferred( + self.get_success(self._set_last_txn(service.id, 4)) + self.get_success(self._insert_txn(service.id, txn_id, events)) + self.get_success( self.store.complete_appservice_txn(txn_id=txn_id, service=service) ) - res = yield self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT last_txn, state FROM application_services_state WHERE as_id=?" - ), - (service.id,), + res = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT last_txn, state FROM application_services_state WHERE as_id=?" + ), + (service.id,), + ) ) self.assertEquals(1, len(res)) self.assertEquals(txn_id, res[0][0]) self.assertEquals(ApplicationServiceState.UP.value, res[0][1]) - res = yield self.db_pool.runQuery( - self.engine.convert_param_style( - "SELECT * FROM application_services_txns WHERE txn_id=?" - ), - (txn_id,), + res = self.get_success( + self.db_pool.runQuery( + self.engine.convert_param_style( + "SELECT * FROM application_services_txns WHERE txn_id=?" + ), + (txn_id,), + ) ) self.assertEquals(0, len(res)) - @defer.inlineCallbacks def test_get_oldest_unsent_txn_none( self, - ) -> Generator["Deferred[object]", Any, None]: + ) -> None: service = Mock(id=self.as_list[0]["id"]) - txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) + txn = self.get_success(self.store.get_oldest_unsent_txn(service)) self.assertEquals(None, txn) - @defer.inlineCallbacks - def test_get_oldest_unsent_txn(self) -> Generator["Deferred[object]", Any, None]: + def test_get_oldest_unsent_txn(self) -> None: service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] other_events = [Mock(event_id="e5"), Mock(event_id="e6")] @@ -404,39 +402,49 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # (ignore needed because Mypy won't allow us to assign to a method otherwise) self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment] - yield self._insert_txn(self.as_list[1]["id"], 9, other_events) - yield self._insert_txn(service.id, 10, events) - yield self._insert_txn(service.id, 11, other_events) - yield self._insert_txn(service.id, 12, other_events) + self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events)) + self.get_success(self._insert_txn(service.id, 10, events)) + self.get_success(self._insert_txn(service.id, 11, other_events)) + self.get_success(self._insert_txn(service.id, 12, other_events)) - txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) + txn = self.get_success(self.store.get_oldest_unsent_txn(service)) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) self.assertEquals(events, txn.events) - @defer.inlineCallbacks def test_get_appservices_by_state_single( self, - ) -> Generator["Deferred[object]", Any, None]: - yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) - yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) + ) -> None: + self.get_success( + self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) + ) + self.get_success( + self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) + ) - services = yield defer.ensureDeferred( + services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(1, len(services)) self.assertEquals(self.as_list[0]["id"], services[0].id) - @defer.inlineCallbacks def test_get_appservices_by_state_multiple( self, - ) -> Generator["Deferred[object]", Any, None]: - yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) - yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) - yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) + ) -> None: + self.get_success( + self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) + ) + self.get_success( + self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) + ) + self.get_success( + self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) + ) + self.get_success( + self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) + ) - services = yield defer.ensureDeferred( + services = self.get_success( self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(2, len(services)) @@ -447,10 +455,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock) -> "HomeServer": - hs = self.setup_test_homeserver() - return hs - def prepare( self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer ) -> None: @@ -512,7 +516,7 @@ class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServic super().__init__(database, db_conn, hs) -class ApplicationServiceStoreConfigTestCase(unittest.TestCase): +class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): def _write_config(self, suffix, **kwargs) -> str: vals = { "id": "id" + suffix, @@ -529,41 +533,33 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): f.write(yaml.dump(vals)) return path - @defer.inlineCallbacks - def test_unique_works(self) -> Generator["Deferred[object]", Any, None]: + def test_unique_works(self) -> None: f1 = self._write_config(suffix="1") f2 = self._write_config(suffix="2") - hs = yield setup_test_homeserver( - self.addCleanup, federation_sender=Mock(), federation_client=Mock() - ) - - hs.config.appservice.app_service_config_files = [f1, f2] - hs.config.caches.event_cache_size = 1 + self.hs.config.appservice.app_service_config_files = [f1, f2] + self.hs.config.caches.event_cache_size = 1 - database = hs.get_datastores().databases[0] + database = self.hs.get_datastores().databases[0] ApplicationServiceStore( - database, make_conn(database._database_config, database.engine, "test"), hs + database, + make_conn(database._database_config, database.engine, "test"), + self.hs, ) - @defer.inlineCallbacks - def test_duplicate_ids(self) -> Generator["Deferred[object]", Any, None]: + def test_duplicate_ids(self) -> None: f1 = self._write_config(id="id", suffix="1") f2 = self._write_config(id="id", suffix="2") - hs = yield setup_test_homeserver( - self.addCleanup, federation_sender=Mock(), federation_client=Mock() - ) - - hs.config.appservice.app_service_config_files = [f1, f2] - hs.config.caches.event_cache_size = 1 + self.hs.config.appservice.app_service_config_files = [f1, f2] + self.hs.config.caches.event_cache_size = 1 with self.assertRaises(ConfigError) as cm: - database = hs.get_datastores().databases[0] + database = self.hs.get_datastores().databases[0] ApplicationServiceStore( database, make_conn(database._database_config, database.engine, "test"), - hs, + self.hs, ) e = cm.exception @@ -571,24 +567,19 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): self.assertIn(f2, str(e)) self.assertIn("id", str(e)) - @defer.inlineCallbacks - def test_duplicate_as_tokens(self) -> Generator["Deferred[object]", Any, None]: + def test_duplicate_as_tokens(self) -> None: f1 = self._write_config(as_token="as_token", suffix="1") f2 = self._write_config(as_token="as_token", suffix="2") - hs = yield setup_test_homeserver( - self.addCleanup, federation_sender=Mock(), federation_client=Mock() - ) - - hs.config.appservice.app_service_config_files = [f1, f2] - hs.config.caches.event_cache_size = 1 + self.hs.config.appservice.app_service_config_files = [f1, f2] + self.hs.config.caches.event_cache_size = 1 with self.assertRaises(ConfigError) as cm: - database = hs.get_datastores().databases[0] + database = self.hs.get_datastores().databases[0] ApplicationServiceStore( database, make_conn(database._database_config, database.engine, "test"), - hs, + self.hs, ) e = cm.exception -- cgit 1.5.1 From f7ec6e7d9e0dc360d9fb41f3a1afd7bdba1475c7 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 3 Dec 2021 11:35:24 +0000 Subject: Convert one of the `setup_test_homeserver`s to `make_test_homeserver_synchronous` and pass in the homeserver rather than calling a same-named function to ask for one. Later commits will jiggle things around to make this sensible. --- tests/server.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) (limited to 'tests') diff --git a/tests/server.py b/tests/server.py index 40cf5b12c3..41eb3995bd 100644 --- a/tests/server.py +++ b/tests/server.py @@ -57,7 +57,6 @@ from synapse.http.site import SynapseRequest from synapse.types import JsonDict from synapse.util import Clock -from tests.utils import setup_test_homeserver as _sth logger = logging.getLogger(__name__) @@ -450,14 +449,11 @@ class ThreadPool: return d -def setup_test_homeserver(cleanup_func, *args, **kwargs): +def make_test_homeserver_synchronous(server: HomeServer) -> None: """ - Set up a synchronous test server, driven by the reactor used by - the homeserver. + Make the given test homeserver's database interactions synchronous. """ - server = _sth(cleanup_func, *args, **kwargs) - # Make the thread pool synchronous. clock = server.get_clock() for database in server.get_datastores().databases: @@ -485,6 +481,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction + # Replace the thread pool with a threadless 'thread' pool pool.threadpool = ThreadPool(clock._reactor) pool.running = True @@ -492,8 +489,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): # thread, so we need to disable the dedicated thread behaviour. server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False - return server - def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: clock = ThreadedMemoryReactorClock() -- cgit 1.5.1 From b3fd99b74a3f6f42a9afd1b19ee4c60e38e8e91a Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 3 Dec 2021 11:37:21 +0000 Subject: Move `tests.utils.setup_test_homeserver` to `tests.server` It had no users. We have just taken the identity of a previous function but don't provide the same behaviour, so we need to fix this in the next commit... --- tests/server.py | 185 ++++++++++++++++++++++++++++++++++++++- tests/storage/test_base.py | 3 +- tests/storage/test_roommember.py | 2 +- tests/utils.py | 175 +----------------------------------- 4 files changed, 188 insertions(+), 177 deletions(-) (limited to 'tests') diff --git a/tests/server.py b/tests/server.py index 41eb3995bd..017e5cf635 100644 --- a/tests/server.py +++ b/tests/server.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import hashlib import json import logging +import time +import uuid +import warnings from collections import deque from io import SEEK_END, BytesIO from typing import ( @@ -27,6 +30,7 @@ from typing import ( Type, Union, ) +from unittest.mock import Mock import attr from typing_extensions import Deque @@ -53,10 +57,24 @@ from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site +from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest +from synapse.server import HomeServer +from synapse.storage import DataStore +from synapse.storage.engines import PostgresEngine, create_engine from synapse.types import JsonDict from synapse.util import Clock +from tests.utils import ( + LEAVE_DB, + POSTGRES_BASE_DB, + POSTGRES_HOST, + POSTGRES_PASSWORD, + POSTGRES_USER, + USE_POSTGRES_FOR_TESTS, + MockClock, + default_config, +) logger = logging.getLogger(__name__) @@ -668,3 +686,168 @@ def connect_client( client.makeConnection(FakeTransport(server, reactor)) return client, server + + +class TestHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + + +def setup_test_homeserver( + cleanup_func, + name="test", + config=None, + reactor=None, + homeserver_to_use: Type[HomeServer] = TestHomeServer, + **kwargs, +): + """ + Setup a homeserver suitable for running tests against. Keyword arguments + are passed to the Homeserver constructor. + + If no datastore is supplied, one is created and given to the homeserver. + + Args: + cleanup_func : The function used to register a cleanup routine for + after the test. + + Calling this method directly is deprecated: you should instead derive from + HomeserverTestCase. + """ + if reactor is None: + from twisted.internet import reactor + + if config is None: + config = default_config(name, parse=True) + + config.ldap_enabled = False + + if "clock" not in kwargs: + kwargs["clock"] = MockClock() + + if USE_POSTGRES_FOR_TESTS: + test_db = "synapse_test_%s" % uuid.uuid4().hex + + database_config = { + "name": "psycopg2", + "args": { + "database": test_db, + "host": POSTGRES_HOST, + "password": POSTGRES_PASSWORD, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, + } + else: + database_config = { + "name": "sqlite3", + "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, + } + + if "db_txn_limit" in kwargs: + database_config["txn_limit"] = kwargs["db_txn_limit"] + + database = DatabaseConnectionConfig("master", database_config) + config.database.databases = [database] + + db_engine = create_engine(database.config) + + # Create the database before we actually try and connect to it, based off + # the template database we generate in setupdb() + if isinstance(db_engine, PostgresEngine): + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + cur.execute( + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) + ) + cur.close() + db_conn.close() + + hs = homeserver_to_use( + name, + config=config, + version_string="Synapse/tests", + reactor=reactor, + ) + + # Install @cache_in_self attributes + for key, val in kwargs.items(): + setattr(hs, "_" + key, val) + + # Mock TLS + hs.tls_server_context_factory = Mock() + hs.tls_client_options_factory = Mock() + + hs.setup() + if homeserver_to_use == TestHomeServer: + hs.setup_background_tasks() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] + + # We need to do cleanup on PostgreSQL + def cleanup(): + import psycopg2 + + # Close all the db pools + database._db_pool.close() + + dropped = False + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for _ in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + + cur.close() + db_conn.close() + + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) + + # bcrypt is far too slow to be doing in unit tests + # Need to let the HS build an auth handler and then mess with it + # because AuthHandler's constructor requires the HS, so we can't make one + # beforehand and pass it in to the HS's constructor (chicken / egg) + async def hash(p): + return hashlib.md5(p.encode("utf8")).hexdigest() + + hs.get_auth_handler().hash = hash + + async def validate_hash(p, h): + return hashlib.md5(p.encode("utf8")).hexdigest() == h + + hs.get_auth_handler().validate_hash = validate_hash + + return hs diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index ddad44bd6c..3e4f0579c9 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -23,7 +23,8 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import create_engine from tests import unittest -from tests.utils import TestHomeServer, default_config +from tests.server import TestHomeServer +from tests.utils import default_config class SQLBaseStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index fccab733c0..5cfdfe9b85 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -19,8 +19,8 @@ from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests import unittest +from tests.server import TestHomeServer from tests.test_utils import event_injection -from tests.utils import TestHomeServer class RoomMemberStoreTestCase(unittest.HomeserverTestCase): diff --git a/tests/utils.py b/tests/utils.py index 983859120f..6d013e8518 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,12 +14,7 @@ # limitations under the License. import atexit -import hashlib import os -import time -import uuid -import warnings -from typing import Type from unittest.mock import Mock, patch from urllib import parse as urlparse @@ -28,14 +23,11 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions -from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.logging.context import current_context, set_current_context -from synapse.server import HomeServer -from synapse.storage import DataStore from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import PostgresEngine, create_engine +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database # set this to True to run the tests against postgres instead of sqlite. @@ -182,171 +174,6 @@ def default_config(name, parse=False): return config_dict -class TestHomeServer(HomeServer): - DATASTORE_CLASS = DataStore - - -def setup_test_homeserver( - cleanup_func, - name="test", - config=None, - reactor=None, - homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs, -): - """ - Setup a homeserver suitable for running tests against. Keyword arguments - are passed to the Homeserver constructor. - - If no datastore is supplied, one is created and given to the homeserver. - - Args: - cleanup_func : The function used to register a cleanup routine for - after the test. - - Calling this method directly is deprecated: you should instead derive from - HomeserverTestCase. - """ - if reactor is None: - from twisted.internet import reactor - - if config is None: - config = default_config(name, parse=True) - - config.ldap_enabled = False - - if "clock" not in kwargs: - kwargs["clock"] = MockClock() - - if USE_POSTGRES_FOR_TESTS: - test_db = "synapse_test_%s" % uuid.uuid4().hex - - database_config = { - "name": "psycopg2", - "args": { - "database": test_db, - "host": POSTGRES_HOST, - "password": POSTGRES_PASSWORD, - "user": POSTGRES_USER, - "cp_min": 1, - "cp_max": 5, - }, - } - else: - database_config = { - "name": "sqlite3", - "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, - } - - if "db_txn_limit" in kwargs: - database_config["txn_limit"] = kwargs["db_txn_limit"] - - database = DatabaseConnectionConfig("master", database_config) - config.database.databases = [database] - - db_engine = create_engine(database.config) - - # Create the database before we actually try and connect to it, based off - # the template database we generate in setupdb() - if isinstance(db_engine, PostgresEngine): - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - cur.execute( - "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) - ) - cur.close() - db_conn.close() - - hs = homeserver_to_use( - name, - config=config, - version_string="Synapse/tests", - reactor=reactor, - ) - - # Install @cache_in_self attributes - for key, val in kwargs.items(): - setattr(hs, "_" + key, val) - - # Mock TLS - hs.tls_server_context_factory = Mock() - hs.tls_client_options_factory = Mock() - - hs.setup() - if homeserver_to_use == TestHomeServer: - hs.setup_background_tasks() - - if isinstance(db_engine, PostgresEngine): - database = hs.get_datastores().databases[0] - - # We need to do cleanup on PostgreSQL - def cleanup(): - import psycopg2 - - # Close all the db pools - database._db_pool.close() - - dropped = False - - # Drop the test database - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - - # Try a few times to drop the DB. Some things may hold on to the - # database for a few more seconds due to flakiness, preventing - # us from dropping it when the test is over. If we can't drop - # it, warn and move on. - for _ in range(5): - try: - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() - dropped = True - except psycopg2.OperationalError as e: - warnings.warn( - "Couldn't drop old db: " + str(e), category=UserWarning - ) - time.sleep(0.5) - - cur.close() - db_conn.close() - - if not dropped: - warnings.warn("Failed to drop old DB.", category=UserWarning) - - if not LEAVE_DB: - # Register the cleanup hook - cleanup_func(cleanup) - - # bcrypt is far too slow to be doing in unit tests - # Need to let the HS build an auth handler and then mess with it - # because AuthHandler's constructor requires the HS, so we can't make one - # beforehand and pass it in to the HS's constructor (chicken / egg) - async def hash(p): - return hashlib.md5(p.encode("utf8")).hexdigest() - - hs.get_auth_handler().hash = hash - - async def validate_hash(p, h): - return hashlib.md5(p.encode("utf8")).hexdigest() == h - - hs.get_auth_handler().validate_hash = validate_hash - - return hs - - def mock_getRawHeaders(headers=None): headers = headers if headers is not None else {} -- cgit 1.5.1 From 7be88fbf48156b36b6daefb228e1258e7d48cae4 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 3 Dec 2021 11:40:05 +0000 Subject: Give `tests.server.setup_test_homeserver` (nominally!) the same behaviour by calling into `make_test_homeserver_synchronous`. The function *could* have been inlined at this point but the function is big enough and it felt fine to leave it as is. At least there isn't a confusing name clash anymore! --- tests/server.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'tests') diff --git a/tests/server.py b/tests/server.py index 017e5cf635..b29df37595 100644 --- a/tests/server.py +++ b/tests/server.py @@ -850,4 +850,7 @@ def setup_test_homeserver( hs.get_auth_handler().validate_hash = validate_hash + # Make the threadpool and database transactions synchronous for testing. + make_test_homeserver_synchronous(hs) + return hs -- cgit 1.5.1 From 8cd68b8102eeab1b525712097c1b2e9679c11896 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 3 Dec 2021 12:31:28 +0000 Subject: Revert accidental commits to develop. --- changelog.d/11503.misc | 1 - tests/server.py | 199 ++------------------------------------- tests/storage/test_base.py | 3 +- tests/storage/test_roommember.py | 2 +- tests/utils.py | 175 +++++++++++++++++++++++++++++++++- 5 files changed, 185 insertions(+), 195 deletions(-) delete mode 100644 changelog.d/11503.misc (limited to 'tests') diff --git a/changelog.d/11503.misc b/changelog.d/11503.misc deleted file mode 100644 index 03a24a9224..0000000000 --- a/changelog.d/11503.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor `tests.util.setup_test_homeserver` and `tests.server.setup_test_homeserver`. \ No newline at end of file diff --git a/tests/server.py b/tests/server.py index b29df37595..40cf5b12c3 100644 --- a/tests/server.py +++ b/tests/server.py @@ -11,12 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import hashlib + import json import logging -import time -import uuid -import warnings from collections import deque from io import SEEK_END, BytesIO from typing import ( @@ -30,7 +27,6 @@ from typing import ( Type, Union, ) -from unittest.mock import Mock import attr from typing_extensions import Deque @@ -57,24 +53,11 @@ from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site -from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest -from synapse.server import HomeServer -from synapse.storage import DataStore -from synapse.storage.engines import PostgresEngine, create_engine from synapse.types import JsonDict from synapse.util import Clock -from tests.utils import ( - LEAVE_DB, - POSTGRES_BASE_DB, - POSTGRES_HOST, - POSTGRES_PASSWORD, - POSTGRES_USER, - USE_POSTGRES_FOR_TESTS, - MockClock, - default_config, -) +from tests.utils import setup_test_homeserver as _sth logger = logging.getLogger(__name__) @@ -467,11 +450,14 @@ class ThreadPool: return d -def make_test_homeserver_synchronous(server: HomeServer) -> None: +def setup_test_homeserver(cleanup_func, *args, **kwargs): """ - Make the given test homeserver's database interactions synchronous. + Set up a synchronous test server, driven by the reactor used by + the homeserver. """ + server = _sth(cleanup_func, *args, **kwargs) + # Make the thread pool synchronous. clock = server.get_clock() for database in server.get_datastores().databases: @@ -499,7 +485,6 @@ def make_test_homeserver_synchronous(server: HomeServer) -> None: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction - # Replace the thread pool with a threadless 'thread' pool pool.threadpool = ThreadPool(clock._reactor) pool.running = True @@ -507,6 +492,8 @@ def make_test_homeserver_synchronous(server: HomeServer) -> None: # thread, so we need to disable the dedicated thread behaviour. server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False + return server + def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: clock = ThreadedMemoryReactorClock() @@ -686,171 +673,3 @@ def connect_client( client.makeConnection(FakeTransport(server, reactor)) return client, server - - -class TestHomeServer(HomeServer): - DATASTORE_CLASS = DataStore - - -def setup_test_homeserver( - cleanup_func, - name="test", - config=None, - reactor=None, - homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs, -): - """ - Setup a homeserver suitable for running tests against. Keyword arguments - are passed to the Homeserver constructor. - - If no datastore is supplied, one is created and given to the homeserver. - - Args: - cleanup_func : The function used to register a cleanup routine for - after the test. - - Calling this method directly is deprecated: you should instead derive from - HomeserverTestCase. - """ - if reactor is None: - from twisted.internet import reactor - - if config is None: - config = default_config(name, parse=True) - - config.ldap_enabled = False - - if "clock" not in kwargs: - kwargs["clock"] = MockClock() - - if USE_POSTGRES_FOR_TESTS: - test_db = "synapse_test_%s" % uuid.uuid4().hex - - database_config = { - "name": "psycopg2", - "args": { - "database": test_db, - "host": POSTGRES_HOST, - "password": POSTGRES_PASSWORD, - "user": POSTGRES_USER, - "cp_min": 1, - "cp_max": 5, - }, - } - else: - database_config = { - "name": "sqlite3", - "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, - } - - if "db_txn_limit" in kwargs: - database_config["txn_limit"] = kwargs["db_txn_limit"] - - database = DatabaseConnectionConfig("master", database_config) - config.database.databases = [database] - - db_engine = create_engine(database.config) - - # Create the database before we actually try and connect to it, based off - # the template database we generate in setupdb() - if isinstance(db_engine, PostgresEngine): - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - cur.execute( - "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) - ) - cur.close() - db_conn.close() - - hs = homeserver_to_use( - name, - config=config, - version_string="Synapse/tests", - reactor=reactor, - ) - - # Install @cache_in_self attributes - for key, val in kwargs.items(): - setattr(hs, "_" + key, val) - - # Mock TLS - hs.tls_server_context_factory = Mock() - hs.tls_client_options_factory = Mock() - - hs.setup() - if homeserver_to_use == TestHomeServer: - hs.setup_background_tasks() - - if isinstance(db_engine, PostgresEngine): - database = hs.get_datastores().databases[0] - - # We need to do cleanup on PostgreSQL - def cleanup(): - import psycopg2 - - # Close all the db pools - database._db_pool.close() - - dropped = False - - # Drop the test database - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - - # Try a few times to drop the DB. Some things may hold on to the - # database for a few more seconds due to flakiness, preventing - # us from dropping it when the test is over. If we can't drop - # it, warn and move on. - for _ in range(5): - try: - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() - dropped = True - except psycopg2.OperationalError as e: - warnings.warn( - "Couldn't drop old db: " + str(e), category=UserWarning - ) - time.sleep(0.5) - - cur.close() - db_conn.close() - - if not dropped: - warnings.warn("Failed to drop old DB.", category=UserWarning) - - if not LEAVE_DB: - # Register the cleanup hook - cleanup_func(cleanup) - - # bcrypt is far too slow to be doing in unit tests - # Need to let the HS build an auth handler and then mess with it - # because AuthHandler's constructor requires the HS, so we can't make one - # beforehand and pass it in to the HS's constructor (chicken / egg) - async def hash(p): - return hashlib.md5(p.encode("utf8")).hexdigest() - - hs.get_auth_handler().hash = hash - - async def validate_hash(p, h): - return hashlib.md5(p.encode("utf8")).hexdigest() == h - - hs.get_auth_handler().validate_hash = validate_hash - - # Make the threadpool and database transactions synchronous for testing. - make_test_homeserver_synchronous(hs) - - return hs diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 3e4f0579c9..ddad44bd6c 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -23,8 +23,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import create_engine from tests import unittest -from tests.server import TestHomeServer -from tests.utils import default_config +from tests.utils import TestHomeServer, default_config class SQLBaseStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5cfdfe9b85..fccab733c0 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -19,8 +19,8 @@ from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests import unittest -from tests.server import TestHomeServer from tests.test_utils import event_injection +from tests.utils import TestHomeServer class RoomMemberStoreTestCase(unittest.HomeserverTestCase): diff --git a/tests/utils.py b/tests/utils.py index 6d013e8518..983859120f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,7 +14,12 @@ # limitations under the License. import atexit +import hashlib import os +import time +import uuid +import warnings +from typing import Type from unittest.mock import Mock, patch from urllib import parse as urlparse @@ -23,11 +28,14 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.logging.context import current_context, set_current_context +from synapse.server import HomeServer +from synapse.storage import DataStore from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import create_engine +from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import prepare_database # set this to True to run the tests against postgres instead of sqlite. @@ -174,6 +182,171 @@ def default_config(name, parse=False): return config_dict +class TestHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + + +def setup_test_homeserver( + cleanup_func, + name="test", + config=None, + reactor=None, + homeserver_to_use: Type[HomeServer] = TestHomeServer, + **kwargs, +): + """ + Setup a homeserver suitable for running tests against. Keyword arguments + are passed to the Homeserver constructor. + + If no datastore is supplied, one is created and given to the homeserver. + + Args: + cleanup_func : The function used to register a cleanup routine for + after the test. + + Calling this method directly is deprecated: you should instead derive from + HomeserverTestCase. + """ + if reactor is None: + from twisted.internet import reactor + + if config is None: + config = default_config(name, parse=True) + + config.ldap_enabled = False + + if "clock" not in kwargs: + kwargs["clock"] = MockClock() + + if USE_POSTGRES_FOR_TESTS: + test_db = "synapse_test_%s" % uuid.uuid4().hex + + database_config = { + "name": "psycopg2", + "args": { + "database": test_db, + "host": POSTGRES_HOST, + "password": POSTGRES_PASSWORD, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, + } + else: + database_config = { + "name": "sqlite3", + "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, + } + + if "db_txn_limit" in kwargs: + database_config["txn_limit"] = kwargs["db_txn_limit"] + + database = DatabaseConnectionConfig("master", database_config) + config.database.databases = [database] + + db_engine = create_engine(database.config) + + # Create the database before we actually try and connect to it, based off + # the template database we generate in setupdb() + if isinstance(db_engine, PostgresEngine): + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + cur.execute( + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) + ) + cur.close() + db_conn.close() + + hs = homeserver_to_use( + name, + config=config, + version_string="Synapse/tests", + reactor=reactor, + ) + + # Install @cache_in_self attributes + for key, val in kwargs.items(): + setattr(hs, "_" + key, val) + + # Mock TLS + hs.tls_server_context_factory = Mock() + hs.tls_client_options_factory = Mock() + + hs.setup() + if homeserver_to_use == TestHomeServer: + hs.setup_background_tasks() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] + + # We need to do cleanup on PostgreSQL + def cleanup(): + import psycopg2 + + # Close all the db pools + database._db_pool.close() + + dropped = False + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for _ in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + + cur.close() + db_conn.close() + + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) + + # bcrypt is far too slow to be doing in unit tests + # Need to let the HS build an auth handler and then mess with it + # because AuthHandler's constructor requires the HS, so we can't make one + # beforehand and pass it in to the HS's constructor (chicken / egg) + async def hash(p): + return hashlib.md5(p.encode("utf8")).hexdigest() + + hs.get_auth_handler().hash = hash + + async def validate_hash(p, h): + return hashlib.md5(p.encode("utf8")).hexdigest() == h + + hs.get_auth_handler().validate_hash = validate_hash + + return hs + + def mock_getRawHeaders(headers=None): headers = headers if headers is not None else {} -- cgit 1.5.1 From e5f426cd54609e7f05f8241d845e6e36c5f10d9a Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 3 Dec 2021 14:57:13 +0100 Subject: Add type hints to `synapse/tests/rest/admin` (#11501) --- changelog.d/11501.misc | 1 + mypy.ini | 3 - tests/rest/admin/test_background_updates.py | 21 +++--- tests/rest/admin/test_device.py | 55 ++++++++-------- tests/rest/admin/test_event_reports.py | 61 ++++++++++-------- tests/rest/admin/test_media.py | 69 ++++++++++---------- tests/rest/admin/test_registration_tokens.py | 77 +++++++++++----------- tests/rest/admin/test_room.py | 95 +++++++++++++++------------- tests/rest/admin/test_server_notice.py | 27 ++++---- tests/rest/admin/test_statistics.py | 40 ++++++------ tests/rest/admin/test_user.py | 36 +++++------ 11 files changed, 257 insertions(+), 228 deletions(-) create mode 100644 changelog.d/11501.misc (limited to 'tests') diff --git a/changelog.d/11501.misc b/changelog.d/11501.misc new file mode 100644 index 0000000000..40e01194df --- /dev/null +++ b/changelog.d/11501.misc @@ -0,0 +1 @@ +Add type hints to `synapse/tests/rest/admin`. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index fea71d154c..1caf807e85 100644 --- a/mypy.ini +++ b/mypy.ini @@ -86,9 +86,6 @@ exclude = (?x) |tests/push/test_presentable_names.py |tests/push/test_push_rule_evaluator.py |tests/rest/admin/test_admin.py - |tests/rest/admin/test_device.py - |tests/rest/admin/test_media.py - |tests/rest/admin/test_server_notice.py |tests/rest/admin/test_user.py |tests/rest/admin/test_username_available.py |tests/rest/client/test_account.py diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index a5423af652..4d152c0d66 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -16,11 +16,14 @@ from typing import Collection from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login from synapse.server import HomeServer from synapse.storage.background_updates import BackgroundUpdater +from synapse.util import Clock from tests import unittest @@ -31,7 +34,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -44,7 +47,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ("POST", "/_synapse/admin/v1/background_updates/start_job"), ] ) - def test_requester_is_no_admin(self, method: str, url: str): + def test_requester_is_no_admin(self, method: str, url: str) -> None: """ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ @@ -62,7 +65,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ @@ -90,7 +93,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - def _register_bg_update(self): + def _register_bg_update(self) -> None: "Adds a bg update but doesn't start it" async def _fake_update(progress, batch_size) -> int: @@ -112,7 +115,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ) ) - def test_status_empty(self): + def test_status_empty(self) -> None: """Test the status API works.""" channel = self.make_request( @@ -127,7 +130,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): channel.json_body, {"current_updates": {}, "enabled": True} ) - def test_status_bg_update(self): + def test_status_bg_update(self) -> None: """Test the status API works with a background update.""" # Create a new background update @@ -162,7 +165,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): }, ) - def test_enabled(self): + def test_enabled(self) -> None: """Test the enabled API works.""" # Create a new background update @@ -299,7 +302,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ), ] ) - def test_start_backround_job(self, job_name: str, updates: Collection[str]): + def test_start_backround_job(self, job_name: str, updates: Collection[str]) -> None: """ Test that background updates add to database and be processed. @@ -341,7 +344,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): ) ) - def test_start_backround_job_twice(self): + def test_start_backround_job_twice(self) -> None: """Test that add a background update twice return an error.""" # add job to database diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index baff057c56..f7080bda87 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import urllib.parse from http import HTTPStatus from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -31,7 +34,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_device_handler() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -48,7 +51,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_no_auth(self, method: str): + def test_no_auth(self, method: str) -> None: """ Try to get a device of an user without authentication. """ @@ -62,7 +65,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_requester_is_no_admin(self, method: str): + def test_requester_is_no_admin(self, method: str) -> None: """ If the user is not a server admin, an error is returned. """ @@ -80,7 +83,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_user_does_not_exist(self, method: str): + def test_user_does_not_exist(self, method: str) -> None: """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -99,7 +102,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @parameterized.expand(["GET", "PUT", "DELETE"]) - def test_user_is_not_local(self, method: str): + def test_user_is_not_local(self, method: str) -> None: """ Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -117,7 +120,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_unknown_device(self): + def test_unknown_device(self) -> None: """ Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK. """ @@ -151,7 +154,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): # Delete unknown device returns status HTTPStatus.OK self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def test_update_device_too_long_display_name(self): + def test_update_device_too_long_display_name(self) -> None: """ Update a device with a display name that is invalid (too long). """ @@ -189,7 +192,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) - def test_update_no_display_name(self): + def test_update_no_display_name(self) -> None: """ Tests that a update for a device without JSON returns a HTTPStatus.OK """ @@ -219,7 +222,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) - def test_update_display_name(self): + def test_update_display_name(self) -> None: """ Tests a normal successful update of display name """ @@ -243,7 +246,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) - def test_get_device(self): + def test_get_device(self) -> None: """ Tests that a normal lookup for a device is successfully """ @@ -262,7 +265,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertIn("last_seen_ip", channel.json_body) self.assertIn("last_seen_ts", channel.json_body) - def test_delete_device(self): + def test_delete_device(self) -> None: """ Tests that a remove of a device is successfully """ @@ -292,7 +295,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -302,7 +305,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to list devices of an user without authentication. """ @@ -315,7 +318,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -334,7 +337,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -348,7 +351,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -363,7 +366,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_user_has_no_devices(self): + def test_user_has_no_devices(self) -> None: """ Tests that a normal lookup for devices is successfully if user has no devices @@ -380,7 +383,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["devices"])) - def test_get_devices(self): + def test_get_devices(self) -> None: """ Tests that a normal lookup for devices is successfully """ @@ -416,7 +419,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_device_handler() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -428,7 +431,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete devices of an user without authentication. """ @@ -441,7 +444,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -460,7 +463,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -474,7 +477,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -489,7 +492,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - def test_unknown_devices(self): + def test_unknown_devices(self) -> None: """ Tests that a remove of a device that does not exist returns HTTPStatus.OK. """ @@ -503,7 +506,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): # Delete unknown devices returns status HTTPStatus.OK self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def test_delete_devices(self): + def test_delete_devices(self) -> None: """ Tests that a remove of devices is successfully """ diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index a9c46ec62d..4f89f8b534 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -11,12 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from http import HTTPStatus +from typing import List + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, report_event, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest @@ -29,7 +34,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -70,7 +75,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/event_reports" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to get an event report without authentication. """ @@ -83,7 +88,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ @@ -101,7 +106,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_default_success(self): + def test_default_success(self) -> None: """ Testing list of reported events """ @@ -118,7 +123,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["event_reports"]) - def test_limit(self): + def test_limit(self) -> None: """ Testing list of reported events with limit """ @@ -135,7 +140,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["event_reports"]) - def test_from(self): + def test_from(self) -> None: """ Testing list of reported events with a defined starting point (from) """ @@ -152,7 +157,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["event_reports"]) - def test_limit_and_from(self): + def test_limit_and_from(self) -> None: """ Testing list of reported events with a defined starting point and limit """ @@ -169,7 +174,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["event_reports"]), 10) self._check_fields(channel.json_body["event_reports"]) - def test_filter_room(self): + def test_filter_room(self) -> None: """ Testing list of reported events with a filter of room """ @@ -189,7 +194,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): for report in channel.json_body["event_reports"]: self.assertEqual(report["room_id"], self.room_id1) - def test_filter_user(self): + def test_filter_user(self) -> None: """ Testing list of reported events with a filter of user """ @@ -209,7 +214,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): for report in channel.json_body["event_reports"]: self.assertEqual(report["user_id"], self.other_user) - def test_filter_user_and_room(self): + def test_filter_user_and_room(self) -> None: """ Testing list of reported events with a filter of user and room """ @@ -230,7 +235,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(report["user_id"], self.other_user) self.assertEqual(report["room_id"], self.room_id1) - def test_valid_search_order(self): + def test_valid_search_order(self) -> None: """ Testing search order. Order by timestamps. """ @@ -271,7 +276,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) report += 1 - def test_invalid_search_order(self): + def test_invalid_search_order(self) -> None: """ Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST """ @@ -290,7 +295,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual("Unknown direction: bar", channel.json_body["error"]) - def test_limit_is_negative(self): + def test_limit_is_negative(self) -> None: """ Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST """ @@ -308,7 +313,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_from_is_negative(self): + def test_from_is_negative(self) -> None: """ Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST """ @@ -326,7 +331,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_next_token(self): + def test_next_token(self) -> None: """ Testing that `next_token` appears at the right place """ @@ -384,7 +389,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["event_reports"]), 1) self.assertNotIn("next_token", channel.json_body) - def _create_event_and_report(self, room_id, user_tok): + def _create_event_and_report(self, room_id: str, user_tok: str) -> None: """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -397,7 +402,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def _create_event_and_report_without_parameters(self, room_id, user_tok): + def _create_event_and_report_without_parameters( + self, room_id: str, user_tok: str + ) -> None: """Create and report an event, but omit reason and score""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -410,7 +417,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def _check_fields(self, content): + def _check_fields(self, content: List[JsonDict]) -> None: """Checks that all attributes are present in an event report""" for c in content: self.assertIn("id", c) @@ -433,7 +440,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): report_event.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -453,7 +460,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): # first created event report gets `id`=2 self.url = "/_synapse/admin/v1/event_reports/2" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to get event report without authentication. """ @@ -466,7 +473,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ @@ -484,7 +491,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_default_success(self): + def test_default_success(self) -> None: """ Testing get a reported event """ @@ -498,7 +505,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self._check_fields(channel.json_body) - def test_invalid_report_id(self): + def test_invalid_report_id(self) -> None: """ Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST. """ @@ -557,7 +564,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_report_id_not_found(self): + def test_report_id_not_found(self) -> None: """ Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND. """ @@ -576,7 +583,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual("Event report not found", channel.json_body["error"]) - def _create_event_and_report(self, room_id, user_tok): + def _create_event_and_report(self, room_id: str, user_tok: str) -> None: """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -589,7 +596,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def _check_fields(self, content): + def _check_fields(self, content: JsonDict) -> None: """Checks that all attributes are present in a event report""" self.assertIn("id", content) self.assertIn("received_ts", content) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 6618279dd1..81e578fd26 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -12,16 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os from http import HTTPStatus from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, profile, room from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.server import FakeSite, make_request @@ -39,7 +42,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -48,7 +51,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.filepaths = MediaFilePaths(hs.config.media.media_store_path) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete media without authentication. """ @@ -63,7 +66,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -85,7 +88,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_media_does_not_exist(self): + def test_media_does_not_exist(self) -> None: """ Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -100,7 +103,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_media_is_not_local(self): + def test_media_is_not_local(self) -> None: """ Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -115,7 +118,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) - def test_delete_media(self): + def test_delete_media(self) -> None: """ Tests that delete a media is successfully """ @@ -208,7 +211,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -221,7 +224,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): # Move clock up to somewhat realistic time self.reactor.advance(1000000000) - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete media without authentication. """ @@ -235,7 +238,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -255,7 +258,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_media_is_not_local(self): + def test_media_is_not_local(self) -> None: """ Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST """ @@ -270,7 +273,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) - def test_missing_parameter(self): + def test_missing_parameter(self) -> None: """ If the parameter `before_ts` is missing, an error is returned. """ @@ -290,7 +293,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): "Missing integer query parameter 'before_ts'", channel.json_body["error"] ) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ @@ -363,7 +366,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_delete_media_never_accessed(self): + def test_delete_media_never_accessed(self) -> None: """ Tests that media deleted if it is older than `before_ts` and never accessed `last_access_ts` is `NULL` and `created_ts` < `before_ts` @@ -394,7 +397,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_date(self): + def test_keep_media_by_date(self) -> None: """ Tests that media is not deleted if it is newer than `before_ts` """ @@ -431,7 +434,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_size(self): + def test_keep_media_by_size(self) -> None: """ Tests that media is not deleted if its size is smaller than or equal to `size_gt` @@ -466,7 +469,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_user_avatar(self): + def test_keep_media_by_user_avatar(self) -> None: """ Tests that we do not delete media if is used as a user avatar Tests parameter `keep_profiles` @@ -510,7 +513,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def test_keep_media_by_room_avatar(self): + def test_keep_media_by_room_avatar(self) -> None: """ Tests that we do not delete media if it is used as a room avatar Tests parameter `keep_profiles` @@ -555,7 +558,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self._access_media(server_and_media_id, False) - def _create_media(self): + def _create_media(self) -> str: """ Create a media and return media_id and server_and_media_id """ @@ -577,7 +580,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): return server_and_media_id - def _access_media(self, server_and_media_id, expect_success=True): + def _access_media(self, server_and_media_id, expect_success=True) -> None: """ Try to access a media and check the result """ @@ -627,7 +630,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() self.store = hs.get_datastore() self.server_name = hs.hostname @@ -652,7 +655,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/media/%s/%s/%s" @parameterized.expand(["quarantine", "unquarantine"]) - def test_no_auth(self, action: str): + def test_no_auth(self, action: str) -> None: """ Try to protect media without authentication. """ @@ -671,7 +674,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["quarantine", "unquarantine"]) - def test_requester_is_no_admin(self, action: str): + def test_requester_is_no_admin(self, action: str) -> None: """ If the user is not a server admin, an error is returned. """ @@ -691,7 +694,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_quarantine_media(self): + def test_quarantine_media(self) -> None: """ Tests that quarantining and remove from quarantine a media is successfully """ @@ -725,7 +728,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): media_info = self.get_success(self.store.get_local_media(self.media_id)) self.assertFalse(media_info["quarantined_by"]) - def test_quarantine_protected_media(self): + def test_quarantine_protected_media(self) -> None: """ Tests that quarantining from protected media fails """ @@ -760,7 +763,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: media_repo = hs.get_media_repository_resource() self.store = hs.get_datastore() @@ -784,7 +787,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/media/%s/%s" @parameterized.expand(["protect", "unprotect"]) - def test_no_auth(self, action: str): + def test_no_auth(self, action: str) -> None: """ Try to protect media without authentication. """ @@ -799,7 +802,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @parameterized.expand(["protect", "unprotect"]) - def test_requester_is_no_admin(self, action: str): + def test_requester_is_no_admin(self, action: str) -> None: """ If the user is not a server admin, an error is returned. """ @@ -819,7 +822,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_protect_media(self): + def test_protect_media(self) -> None: """ Tests that protect and unprotect a media is successfully """ @@ -864,7 +867,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname @@ -874,7 +877,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.url = "/_synapse/admin/v1/purge_media_cache" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to delete media without authentication. """ @@ -888,7 +891,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_not_admin(self): + def test_requester_is_not_admin(self) -> None: """ If the user is not a server admin, an error is returned. """ @@ -908,7 +911,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 63087955f2..350a62dda6 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import random import string from http import HTTPStatus +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -29,7 +32,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -39,7 +42,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/registration_tokens" - def _new_token(self, **kwargs): + def _new_token(self, **kwargs) -> str: """Helper function to create a token.""" token = kwargs.get( "token", @@ -61,7 +64,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # CREATION - def test_create_no_auth(self): + def test_create_no_auth(self) -> None: """Try to create a token without authentication.""" channel = self.make_request("POST", self.url + "/new", {}) self.assertEqual( @@ -71,7 +74,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_create_requester_not_admin(self): + def test_create_requester_not_admin(self) -> None: """Try to create a token while not an admin.""" channel = self.make_request( "POST", @@ -86,7 +89,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_create_using_defaults(self): + def test_create_using_defaults(self) -> None: """Create a token using all the defaults.""" channel = self.make_request( "POST", @@ -102,7 +105,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_specifying_fields(self): + def test_create_specifying_fields(self) -> None: """Create a token specifying the value of all fields.""" # As many of the allowed characters as possible with length <= 64 token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-" @@ -126,7 +129,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_with_null_value(self): + def test_create_with_null_value(self) -> None: """Create a token specifying unlimited uses and no expiry.""" data = { "uses_allowed": None, @@ -147,7 +150,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["completed"], 0) - def test_create_token_too_long(self): + def test_create_token_too_long(self) -> None: """Check token longer than 64 chars is invalid.""" data = {"token": "a" * 65} @@ -165,7 +168,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_invalid_chars(self): + def test_create_token_invalid_chars(self) -> None: """Check you can't create token with invalid characters.""" data = { "token": "abc/def", @@ -185,7 +188,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_token_already_exists(self): + def test_create_token_already_exists(self) -> None: """Check you can't create token that already exists.""" data = { "token": "abcd", @@ -208,7 +211,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body) self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_unable_to_generate_token(self): + def test_create_unable_to_generate_token(self) -> None: """Check right error is raised when server can't generate unique token.""" # Create all possible single character tokens tokens = [] @@ -239,7 +242,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(500, channel.code, msg=channel.json_body) - def test_create_uses_allowed(self): + def test_create_uses_allowed(self) -> None: """Check you can only create a token with good values for uses_allowed.""" # Should work with 0 (token is invalid from the start) channel = self.make_request( @@ -279,7 +282,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_expiry_time(self): + def test_create_expiry_time(self) -> None: """Check you can't create a token with an invalid expiry_time.""" # Should fail with a time in the past channel = self.make_request( @@ -309,7 +312,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_create_length(self): + def test_create_length(self) -> None: """Check you can only generate a token with a valid length.""" # Should work with 64 channel = self.make_request( @@ -379,7 +382,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # UPDATING - def test_update_no_auth(self): + def test_update_no_auth(self) -> None: """Try to update a token without authentication.""" channel = self.make_request( "PUT", @@ -393,7 +396,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_update_requester_not_admin(self): + def test_update_requester_not_admin(self) -> None: """Try to update a token while not an admin.""" channel = self.make_request( "PUT", @@ -408,7 +411,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_update_non_existent(self): + def test_update_non_existent(self) -> None: """Try to update a token that doesn't exist.""" channel = self.make_request( "PUT", @@ -424,7 +427,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_update_uses_allowed(self): + def test_update_uses_allowed(self) -> None: """Test updating just uses_allowed.""" # Create new token using default values token = self._new_token() @@ -490,7 +493,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_expiry_time(self): + def test_update_expiry_time(self) -> None: """Test updating just expiry_time.""" # Create new token using default values token = self._new_token() @@ -547,7 +550,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) - def test_update_both(self): + def test_update_both(self) -> None: """Test updating both uses_allowed and expiry_time.""" # Create new token using default values token = self._new_token() @@ -569,7 +572,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) - def test_update_invalid_type(self): + def test_update_invalid_type(self) -> None: """Test using invalid types doesn't work.""" # Create new token using default values token = self._new_token() @@ -595,7 +598,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # DELETING - def test_delete_no_auth(self): + def test_delete_no_auth(self) -> None: """Try to delete a token without authentication.""" channel = self.make_request( "DELETE", @@ -609,7 +612,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_delete_requester_not_admin(self): + def test_delete_requester_not_admin(self) -> None: """Try to delete a token while not an admin.""" channel = self.make_request( "DELETE", @@ -624,7 +627,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_delete_non_existent(self): + def test_delete_non_existent(self) -> None: """Try to delete a token that doesn't exist.""" channel = self.make_request( "DELETE", @@ -640,7 +643,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_delete(self): + def test_delete(self) -> None: """Test deleting a token.""" # Create new token using default values token = self._new_token() @@ -656,7 +659,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # GETTING ONE - def test_get_no_auth(self): + def test_get_no_auth(self) -> None: """Try to get a token without authentication.""" channel = self.make_request( "GET", @@ -670,7 +673,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_get_requester_not_admin(self): + def test_get_requester_not_admin(self) -> None: """Try to get a token while not an admin.""" channel = self.make_request( "GET", @@ -685,7 +688,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_get_non_existent(self): + def test_get_non_existent(self) -> None: """Try to get a token that doesn't exist.""" channel = self.make_request( "GET", @@ -701,7 +704,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - def test_get(self): + def test_get(self) -> None: """Test getting a token.""" # Create new token using default values token = self._new_token() @@ -722,7 +725,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # LISTING - def test_list_no_auth(self): + def test_list_no_auth(self) -> None: """Try to list tokens without authentication.""" channel = self.make_request("GET", self.url, {}) self.assertEqual( @@ -732,7 +735,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_list_requester_not_admin(self): + def test_list_requester_not_admin(self) -> None: """Try to list tokens while not an admin.""" channel = self.make_request( "GET", @@ -747,7 +750,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_list_all(self): + def test_list_all(self) -> None: """Test listing all tokens.""" # Create new token using default values token = self._new_token() @@ -768,7 +771,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertEqual(token_info["pending"], 0) self.assertEqual(token_info["completed"], 0) - def test_list_invalid_query_parameter(self): + def test_list_invalid_query_parameter(self) -> None: """Test with `valid` query parameter not `true` or `false`.""" channel = self.make_request( "GET", @@ -783,7 +786,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): msg=channel.json_body, ) - def _test_list_query_parameter(self, valid: str): + def _test_list_query_parameter(self, valid: str) -> None: """Helper used to test both valid=true and valid=false.""" # Create 2 valid and 2 invalid tokens. now = self.hs.get_clock().time_msec() @@ -820,10 +823,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): self.assertIn(token_info_1["token"], tokens) self.assertIn(token_info_2["token"], tokens) - def test_list_valid(self): + def test_list_valid(self) -> None: """Test listing just valid tokens.""" self._test_list_query_parameter(valid="true") - def test_list_invalid(self): + def test_list_invalid(self) -> None: """Test listing just invalid tokens.""" self._test_list_query_parameter(valid="false") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 681f9173ef..d3858e460d 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import urllib.parse from http import HTTPStatus from typing import List, Optional @@ -19,11 +18,15 @@ from unittest.mock import Mock from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -39,7 +42,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): room.register_deprecated_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.event_creation_handler = hs.get_event_creation_handler() hs.config.consent.user_consent_version = "1" @@ -455,7 +458,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): room.register_deprecated_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.event_creation_handler = hs.get_event_creation_handler() hs.config.consent.user_consent_version = "1" @@ -1062,12 +1065,12 @@ class RoomTestCase(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Create user self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - def test_list_rooms(self): + def test_list_rooms(self) -> None: """Test that we can list rooms""" # Create 3 test rooms total_rooms = 3 @@ -1131,7 +1134,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # We shouldn't receive a next token here as there's no further rooms to show self.assertNotIn("next_batch", channel.json_body) - def test_list_rooms_pagination(self): + def test_list_rooms_pagination(self) -> None: """Test that we can get a full list of rooms through pagination""" # Create 5 test rooms total_rooms = 5 @@ -1213,7 +1216,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - def test_correct_room_attributes(self): + def test_correct_room_attributes(self) -> None: """Test the correct attributes for a room are returned""" # Create a test room room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1294,7 +1297,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(test_room_name, r["name"]) self.assertEqual(test_alias, r["canonical_alias"]) - def test_room_list_sort_order(self): + def test_room_list_sort_order(self) -> None: """Test room list sort ordering. alphabetical name versus number of members, reversing the order, etc. """ @@ -1303,7 +1306,7 @@ class RoomTestCase(unittest.HomeserverTestCase): order_type: str, expected_room_list: List[str], reverse: bool = False, - ): + ) -> None: """Request the list of rooms in a certain order. Assert that order is what we expect @@ -1432,7 +1435,7 @@ class RoomTestCase(unittest.HomeserverTestCase): _order_test("state_events", [room_id_3, room_id_2, room_id_1]) _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) - def test_search_term(self): + def test_search_term(self) -> None: """Test that searching for a room works correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1461,7 +1464,7 @@ class RoomTestCase(unittest.HomeserverTestCase): expected_room_id: Optional[str], search_term: str, expected_http_code: int = HTTPStatus.OK, - ): + ) -> None: """Search for a room and check that the returned room's id is a match Args: @@ -1535,7 +1538,7 @@ class RoomTestCase(unittest.HomeserverTestCase): # Test search local part of alias _search_test(room_id_1, "alias1") - def test_search_term_non_ascii(self): + def test_search_term_non_ascii(self) -> None: """Test that searching for a room with non-ASCII characters works correctly""" # Create test room @@ -1562,7 +1565,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name")) - def test_single_room(self): + def test_single_room(self) -> None: """Test that a single room can be requested correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1613,7 +1616,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id_1, channel.json_body["room_id"]) - def test_single_room_devices(self): + def test_single_room_devices(self) -> None: """Test that `joined_local_devices` can be requested correctly""" room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1652,7 +1655,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["joined_local_devices"]) - def test_room_members(self): + def test_room_members(self) -> None: """Test that room members can be requested correctly""" # Create two test rooms room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1700,7 +1703,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["total"], 3) - def test_room_state(self): + def test_room_state(self) -> None: """Test that room state can be requested correctly""" # Create two test rooms room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) @@ -1717,7 +1720,9 @@ class RoomTestCase(unittest.HomeserverTestCase): # the create_room already does the right thing, so no need to verify that we got # the state events it created. - def _set_canonical_alias(self, room_id: str, test_alias: str, admin_user_tok: str): + def _set_canonical_alias( + self, room_id: str, test_alias: str, admin_user_tok: str + ) -> None: # Create a new alias to this room url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) channel = self.make_request( @@ -1752,7 +1757,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -1767,7 +1772,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): ) self.url = f"/_synapse/admin/v1/join/{self.public_room_id}" - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ @@ -1782,7 +1787,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If a parameter is missing, return an error """ @@ -1797,7 +1802,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) - def test_local_user_does_not_exist(self): + def test_local_user_does_not_exist(self) -> None: """ Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND """ @@ -1812,7 +1817,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_remote_user(self): + def test_remote_user(self) -> None: """ Check that only local user can join rooms. """ @@ -1830,7 +1835,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_room_does_not_exist(self): + def test_room_does_not_exist(self) -> None: """ Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. """ @@ -1846,7 +1851,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual("No known servers", channel.json_body["error"]) - def test_room_is_not_valid(self): + def test_room_is_not_valid(self) -> None: """ Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. """ @@ -1865,7 +1870,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_join_public_room(self): + def test_join_public_room(self) -> None: """ Test joining a local user to a public room with "JoinRules.PUBLIC" """ @@ -1890,7 +1895,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) - def test_join_private_room_if_not_member(self): + def test_join_private_room_if_not_member(self) -> None: """ Test joining a local user to a private room with "JoinRules.INVITE" when server admin is not member of this room. @@ -1910,7 +1915,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_join_private_room_if_member(self): + def test_join_private_room_if_member(self) -> None: """ Test joining a local user to a private room with "JoinRules.INVITE", when server admin is member of this room. @@ -1961,7 +1966,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - def test_join_private_room_if_owner(self): + def test_join_private_room_if_owner(self) -> None: """ Test joining a local user to a private room with "JoinRules.INVITE", when server admin is owner of this room. @@ -1991,7 +1996,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) - def test_context_as_non_admin(self): + def test_context_as_non_admin(self) -> None: """ Test that, without being admin, one cannot use the context admin API """ @@ -2025,7 +2030,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_context_as_admin(self): + def test_context_as_admin(self) -> None: """ Test that, as admin, we can find the context of an event without having joined the room. """ @@ -2081,7 +2086,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, homeserver): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2098,7 +2103,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): self.public_room_id ) - def test_public_room(self): + def test_public_room(self) -> None: """Test that getting admin in a public room works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -2123,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.admin_user_tok, ) - def test_private_room(self): + def test_private_room(self) -> None: """Test that getting admin in a private room works and we get invited.""" room_id = self.helper.create_room_as( self.creator, @@ -2151,7 +2156,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.admin_user_tok, ) - def test_other_user(self): + def test_other_user(self) -> None: """Test that giving admin in a public room works to a non-admin user works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -2176,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): tok=self.second_tok, ) - def test_not_enough_power(self): + def test_not_enough_power(self) -> None: """Test that we get a sensible error if there are no local room admins.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True @@ -2216,7 +2221,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self._store = hs.get_datastore() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -2231,7 +2236,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/rooms/%s/block" @parameterized.expand([("PUT",), ("GET",)]) - def test_requester_is_no_admin(self, method: str): + def test_requester_is_no_admin(self, method: str) -> None: """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.""" channel = self.make_request( @@ -2245,7 +2250,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @parameterized.expand([("PUT",), ("GET",)]) - def test_room_is_not_valid(self, method: str): + def test_room_is_not_valid(self, method: str) -> None: """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.""" channel = self.make_request( @@ -2261,7 +2266,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_block_is_not_valid(self): + def test_block_is_not_valid(self) -> None: """If parameter `block` is not valid, return an error.""" # `block` is not valid @@ -2296,7 +2301,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) - def test_block_room(self): + def test_block_room(self) -> None: """Test that block a room is successful.""" def _request_and_test_block_room(room_id: str) -> None: @@ -2320,7 +2325,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): # unknown remote room _request_and_test_block_room("!unknown:remote") - def test_block_room_twice(self): + def test_block_room_twice(self) -> None: """Test that block a room that is already blocked is successful.""" self._is_blocked(self.room_id, expect=False) @@ -2335,7 +2340,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertTrue(channel.json_body["block"]) self._is_blocked(self.room_id, expect=True) - def test_unblock_room(self): + def test_unblock_room(self) -> None: """Test that unblock a room is successful.""" def _request_and_test_unblock_room(room_id: str) -> None: @@ -2360,7 +2365,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): # unknown remote room _request_and_test_unblock_room("!unknown:remote") - def test_unblock_room_twice(self): + def test_unblock_room_twice(self) -> None: """Test that unblock a room that is not blocked is successful.""" self._block_room(self.room_id) @@ -2375,7 +2380,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body["block"]) self._is_blocked(self.room_id, expect=False) - def test_get_blocked_room(self): + def test_get_blocked_room(self) -> None: """Test get status of a blocked room""" def _request_blocked_room(room_id: str) -> None: @@ -2399,7 +2404,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase): # unknown remote room _request_blocked_room("!unknown:remote") - def test_get_unblocked_room(self): + def test_get_unblocked_room(self) -> None: """Test get status of a unblocked room""" def _request_unblocked_room(room_id: str) -> None: diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 0b9da4c732..3c59f5f766 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from http import HTTPStatus from typing import List +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login, room, sync +from synapse.server import HomeServer from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -34,7 +37,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastore() self.room_shutdown_handler = hs.get_room_shutdown_handler() self.pagination_handler = hs.get_pagination_handler() @@ -49,7 +52,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/send_server_notice" - def test_no_auth(self): + def test_no_auth(self) -> None: """Try to send a server notice without authentication.""" channel = self.make_request("POST", self.url) @@ -60,7 +63,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """If the user is not a server admin, an error is returned.""" channel = self.make_request( "POST", @@ -76,7 +79,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_user_does_not_exist(self): + def test_user_does_not_exist(self) -> None: """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" channel = self.make_request( "POST", @@ -89,7 +92,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_user_is_not_local(self): + def test_user_is_not_local(self) -> None: """ Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST """ @@ -109,7 +112,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """If parameters are invalid, an error is returned.""" # no content, no user @@ -157,7 +160,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual("'msgtype' not in content", channel.json_body["error"]) - def test_server_notice_disabled(self): + def test_server_notice_disabled(self) -> None: """Tests that server returns error if server notice is disabled""" channel = self.make_request( "POST", @@ -176,7 +179,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): ) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice(self): + def test_send_server_notice(self) -> None: """ Tests that sending two server notices is successfully, the server uses the same room and do not send messages twice. @@ -240,7 +243,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(messages[1]["sender"], "@notices:test") @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice_leave_room(self): + def test_send_server_notice_leave_room(self) -> None: """ Tests that sending a server notices is successfully. The user leaves the room and the second message appears @@ -324,7 +327,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertNotEqual(first_room_id, second_room_id) @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) - def test_send_server_notice_delete_room(self): + def test_send_server_notice_delete_room(self) -> None: """ Tests that the user get server notice in a new room after the first server notice room was deleted. @@ -414,7 +417,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): def _check_invite_and_join_status( self, user_id: str, expected_invites: int, expected_memberships: int - ) -> RoomsForUser: + ) -> List[RoomsForUser]: """Check invite and room membership status of a user. Args diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 43d8ca032b..7cb8ec57ba 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -12,13 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from http import HTTPStatus -from typing import Any, Dict, List, Optional +from typing import List, Optional + +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import SMALL_PNG @@ -30,7 +34,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -41,7 +45,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/statistics/users/media" - def test_no_auth(self): + def test_no_auth(self) -> None: """ Try to list users without authentication. """ @@ -54,7 +58,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + def test_requester_is_no_admin(self) -> None: """ If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. """ @@ -72,7 +76,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_invalid_parameter(self): + def test_invalid_parameter(self) -> None: """ If parameters are invalid, an error is returned. """ @@ -188,7 +192,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_limit(self): + def test_limit(self) -> None: """ Testing list of media with limit """ @@ -206,7 +210,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["users"]) - def test_from(self): + def test_from(self) -> None: """ Testing list of media with a defined starting point (from) """ @@ -224,7 +228,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["users"]) - def test_limit_and_from(self): + def test_limit_and_from(self) -> None: """ Testing list of media with a defined starting point and limit """ @@ -242,7 +246,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["users"]), 10) self._check_fields(channel.json_body["users"]) - def test_next_token(self): + def test_next_token(self) -> None: """ Testing that `next_token` appears at the right place """ @@ -302,7 +306,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) - def test_no_media(self): + def test_no_media(self) -> None: """ Tests that a normal lookup for statistics is successfully if users have no media created @@ -318,7 +322,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["users"])) - def test_order_by(self): + def test_order_by(self) -> None: """ Testing order list with parameter `order_by` """ @@ -396,7 +400,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): "b", ) - def test_from_until_ts(self): + def test_from_until_ts(self) -> None: """ Testing filter by time with parameters `from_ts` and `until_ts` """ @@ -448,7 +452,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) - def test_search_term(self): + def test_search_term(self) -> None: self._create_users_with_media(20, 1) # check without filter get all users @@ -488,7 +492,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], 0) - def _create_users_with_media(self, number_users: int, media_per_user: int): + def _create_users_with_media(self, number_users: int, media_per_user: int) -> None: """ Create a number of users with a number of media Args: @@ -500,7 +504,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): user_tok = self.login("foo_user_%s" % i, "pass") self._create_media(user_tok, media_per_user) - def _create_media(self, user_token: str, number_media: int): + def _create_media(self, user_token: str, number_media: int) -> None: """ Create a number of media for a specific user Args: @@ -514,7 +518,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK ) - def _check_fields(self, content: List[Dict[str, Any]]): + def _check_fields(self, content: List[JsonDict]) -> None: """Checks that all attributes are present in content Args: content: List that is checked for content @@ -527,7 +531,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): def _order_test( self, order_type: str, expected_user_list: List[str], dir: Optional[str] = None - ): + ) -> None: """Request the list of users in a certain order. Assert that order is what we expect Args: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 03aa689ace..4fedd5fd08 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -127,14 +127,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob", "password": "abc123", "admin": True, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -153,7 +153,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update( nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, @@ -161,7 +161,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "password": "abc123", "admin": True, "user_type": UserTypes.SUPPORT, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -177,14 +177,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob", "password": "abc123", "admin": True, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -308,13 +308,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob1", "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -332,14 +332,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob2", "displayname": None, "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -356,14 +356,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob3", "displayname": "", "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -379,14 +379,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin") - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, "username": "bob4", "displayname": "Bob's Name", "password": "abc123", - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -426,7 +426,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update( nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" ) - want_mac = want_mac.hexdigest() + want_mac_str = want_mac.hexdigest() body = { "nonce": nonce, @@ -434,7 +434,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): "password": "abc123", "admin": True, "user_type": UserTypes.SUPPORT, - "mac": want_mac, + "mac": want_mac_str, } channel = self.make_request("POST", self.url, body) @@ -870,7 +870,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(expected_user_list, returned_order) self._check_fields(channel.json_body["users"]) - def _check_fields(self, content: JsonDict): + def _check_fields(self, content: List[JsonDict]): """Checks that the expected user attributes are present in content Args: content: List that is checked for content @@ -3235,7 +3235,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): return media_id - def _check_fields(self, content: JsonDict): + def _check_fields(self, content: List[JsonDict]): """Checks that the expected user attributes are present in content Args: content: List that is checked for content -- cgit 1.5.1 From 637df95de63196033a6da4a6e286e1d58ea517b6 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 3 Dec 2021 16:42:44 +0000 Subject: Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. (#11445) --- changelog.d/11445.feature | 1 + synapse/config/registration.py | 49 ++++++++++++++++++++ synapse/handlers/register.py | 20 ++++++-- tests/config/test_registration_config.py | 78 ++++++++++++++++++++++++++++++++ tests/rest/client/test_auth.py | 76 +++++++++++++++++++++++++++++++ 5 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11445.feature create mode 100644 tests/config/test_registration_config.py (limited to 'tests') diff --git a/changelog.d/11445.feature b/changelog.d/11445.feature new file mode 100644 index 0000000000..211a722b65 --- /dev/null +++ b/changelog.d/11445.feature @@ -0,0 +1 @@ +Support configuring the lifetime of non-refreshable access tokens separately to refreshable access tokens. \ No newline at end of file diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 47853199f4..68a4985398 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -130,11 +130,60 @@ class RegistrationConfig(Config): int ] = refreshable_access_token_lifetime + if ( + self.session_lifetime is not None + and "refreshable_access_token_lifetime" in config + ): + if self.session_lifetime < self.refreshable_access_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `refreshable_access_token_lifetime` " + "configuration options have been set, but `refreshable_access_token_lifetime` " + " exceeds `session_lifetime`!" + ) + + # The `nonrefreshable_access_token_lifetime` applies for tokens that can NOT be + # refreshed using a refresh token. + # If it is None, then these tokens last for the entire length of the session, + # which is infinite by default. + # The intention behind this configuration option is to help with requiring + # all clients to use refresh tokens, if the homeserver administrator requires. + nonrefreshable_access_token_lifetime = config.get( + "nonrefreshable_access_token_lifetime", + None, + ) + if nonrefreshable_access_token_lifetime is not None: + nonrefreshable_access_token_lifetime = self.parse_duration( + nonrefreshable_access_token_lifetime + ) + self.nonrefreshable_access_token_lifetime = nonrefreshable_access_token_lifetime + + if ( + self.session_lifetime is not None + and self.nonrefreshable_access_token_lifetime is not None + ): + if self.session_lifetime < self.nonrefreshable_access_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `nonrefreshable_access_token_lifetime` " + "configuration options have been set, but `nonrefreshable_access_token_lifetime` " + " exceeds `session_lifetime`!" + ) + refresh_token_lifetime = config.get("refresh_token_lifetime") if refresh_token_lifetime is not None: refresh_token_lifetime = self.parse_duration(refresh_token_lifetime) self.refresh_token_lifetime: Optional[int] = refresh_token_lifetime + if ( + self.session_lifetime is not None + and self.refresh_token_lifetime is not None + ): + if self.session_lifetime < self.refresh_token_lifetime: + raise ConfigError( + "Both `session_lifetime` and `refresh_token_lifetime` " + "configuration options have been set, but `refresh_token_lifetime` " + " exceeds `session_lifetime`!" + ) + # The fallback template used for authenticating using a registration token self.registration_token_template = self.read_template("registration_token.html") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 24ca11b924..b14ddd8267 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -1,4 +1,5 @@ # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -116,6 +117,9 @@ class RegistrationHandler: self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.registration.session_lifetime + self.nonrefreshable_access_token_lifetime = ( + hs.config.registration.nonrefreshable_access_token_lifetime + ) self.refreshable_access_token_lifetime = ( hs.config.registration.refreshable_access_token_lifetime ) @@ -794,13 +798,25 @@ class RegistrationHandler: class and RegisterDeviceReplicationServlet. """ assert not self.hs.config.worker.worker_app + now_ms = self.clock.time_msec() access_token_expiry = None if self.session_lifetime is not None: if is_guest: raise Exception( "session_lifetime is not currently implemented for guest access" ) - access_token_expiry = self.clock.time_msec() + self.session_lifetime + access_token_expiry = now_ms + self.session_lifetime + + if self.nonrefreshable_access_token_lifetime is not None: + if access_token_expiry is not None: + # Don't allow the non-refreshable access token to outlive the + # session. + access_token_expiry = min( + now_ms + self.nonrefreshable_access_token_lifetime, + access_token_expiry, + ) + else: + access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime refresh_token = None refresh_token_id = None @@ -818,8 +834,6 @@ class RegistrationHandler: # that this value is set before setting this flag). assert self.refreshable_access_token_lifetime is not None - now_ms = self.clock.time_msec() - # Set the expiry time of the refreshable access token access_token_expiry = now_ms + self.refreshable_access_token_lifetime diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py new file mode 100644 index 0000000000..17a84d20d8 --- /dev/null +++ b/tests/config/test_registration_config.py @@ -0,0 +1,78 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config import ConfigError +from synapse.config.homeserver import HomeServerConfig + +from tests.unittest import TestCase +from tests.utils import default_config + + +class RegistrationConfigTestCase(TestCase): + def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self): + """ + session_lifetime should logically be larger than, or at least as large as, + all the different token lifetimes. + Test that the user is faced with configuration errors if they make it + smaller, as that configuration doesn't make sense. + """ + config_dict = default_config("test") + + # First test all the error conditions + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "nonrefreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "refreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + with self.assertRaises(ConfigError): + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "30m", + "refresh_token_lifetime": "31m", + **config_dict, + } + ) + + # Then test all the fine conditions + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "31m", + "nonrefreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + HomeServerConfig().parse_config_dict( + { + "session_lifetime": "31m", + "refreshable_access_token_lifetime": "31m", + **config_dict, + } + ) + + HomeServerConfig().parse_config_dict( + {"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict} + ) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index d8a94f4c12..7239e1a1b5 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -524,6 +524,19 @@ class RefreshAuthTests(unittest.HomeserverTestCase): {"refresh_token": refresh_token}, ) + def is_access_token_valid(self, access_token) -> bool: + """ + Checks whether an access token is valid, returning whether it is or not. + """ + code = self.make_request( + "GET", "/_matrix/client/v3/account/whoami", access_token=access_token + ).code + + # Either 200 or 401 is what we get back; anything else is a bug. + assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED} + + return code == HTTPStatus.OK + def test_login_issue_refresh_token(self): """ A login response should include a refresh_token only if asked. @@ -671,6 +684,69 @@ class RefreshAuthTests(unittest.HomeserverTestCase): HTTPStatus.UNAUTHORIZED, ) + @override_config( + { + "refreshable_access_token_lifetime": "1m", + "nonrefreshable_access_token_lifetime": "10m", + } + ) + def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self): + """ + Tests that the expiry times for refreshable and non-refreshable access + tokens can be different. + """ + body = { + "type": "m.login.password", + "user": "test", + "password": self.user_pass, + } + login_response1 = self.make_request( + "POST", + "/_matrix/client/r0/login", + {"org.matrix.msc2918.refresh_token": True, **body}, + ) + self.assertEqual(login_response1.code, 200, login_response1.result) + self.assertApproximates( + login_response1.json_body["expires_in_ms"], 60 * 1000, 100 + ) + refreshable_access_token = login_response1.json_body["access_token"] + + login_response2 = self.make_request( + "POST", + "/_matrix/client/r0/login", + body, + ) + self.assertEqual(login_response2.code, 200, login_response2.result) + nonrefreshable_access_token = login_response2.json_body["access_token"] + + # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry) + self.reactor.advance(59.0) + + # Both tokens should still be valid. + self.assertTrue(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 61 s (just past 1 minute, the time of expiry) + self.reactor.advance(2.0) + + # Only the non-refreshable token is still valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 599 s (just shy of 10 minutes, the time of expiry) + self.reactor.advance(599.0 - 61.0) + + # It's still the case that only the non-refreshable token is still valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token)) + + # Advance to 601 s (just past 10 minutes, the time of expiry) + self.reactor.advance(2.0) + + # Now neither token is valid. + self.assertFalse(self.is_access_token_valid(refreshable_access_token)) + self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token)) + @override_config( {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"} ) -- cgit 1.5.1 From a77c36989785c0d5565ab9a1169f4f88e512ce8a Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 6 Dec 2021 11:36:08 +0000 Subject: Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common` (#11505) --- changelog.d/11505.misc | 1 + synapse/config/room_directory.py | 3 +- synapse/config/tls.py | 3 +- synapse/federation/federation_server.py | 3 +- synapse/push/push_rule_evaluator.py | 7 ++-- synapse/python_dependencies.py | 1 + synapse/util/__init__.py | 59 +-------------------------------- tests/util/test_glob_to_regex.py | 59 --------------------------------- 8 files changed, 13 insertions(+), 123 deletions(-) create mode 100644 changelog.d/11505.misc delete mode 100644 tests/util/test_glob_to_regex.py (limited to 'tests') diff --git a/changelog.d/11505.misc b/changelog.d/11505.misc new file mode 100644 index 0000000000..926b562fad --- /dev/null +++ b/changelog.d/11505.misc @@ -0,0 +1 @@ +Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common`. diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 57316c59b6..3c5e0f7ce7 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -15,8 +15,9 @@ from typing import List +from matrix_common.regex import glob_to_regex + from synapse.types import JsonDict -from synapse.util import glob_to_regex from ._base import Config, ConfigError diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 4ca111618f..3e235b57a7 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -16,11 +16,12 @@ import logging import os from typing import List, Optional, Pattern +from matrix_common.regex import glob_to_regex + from OpenSSL import SSL, crypto from twisted.internet._sslverify import Certificate, trustRootFromCertificates from synapse.config._base import Config, ConfigError -from synapse.util import glob_to_regex logger = logging.getLogger(__name__) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8e37e76206..4697a62c18 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -28,6 +28,7 @@ from typing import ( Union, ) +from matrix_common.regex import glob_to_regex from prometheus_client import Counter, Gauge, Histogram from twisted.internet import defer @@ -66,7 +67,7 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.lock import Lock from synapse.types import JsonDict, get_domain_from_id -from synapse.util import glob_to_regex, json_decoder, unwrapFirstError +from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_server_name diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 7f68092ec5..659a53805d 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -17,9 +17,10 @@ import logging import re from typing import Any, Dict, List, Optional, Pattern, Tuple, Union +from matrix_common.regex import glob_to_regex, to_word_pattern + from synapse.events import EventBase from synapse.types import JsonDict, UserID -from synapse.util import glob_to_regex, re_word_boundary from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -184,7 +185,7 @@ class PushRuleEvaluatorForEvent: r = regex_cache.get((display_name, False, True), None) if not r: r1 = re.escape(display_name) - r1 = re_word_boundary(r1) + r1 = to_word_pattern(r1) r = re.compile(r1, flags=re.IGNORECASE) regex_cache[(display_name, False, True)] = r @@ -213,7 +214,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: try: r = regex_cache.get((glob, True, word_boundary), None) if not r: - r = glob_to_regex(glob, word_boundary) + r = glob_to_regex(glob, word_boundary=word_boundary) regex_cache[(glob, True, word_boundary)] = r return bool(r.search(value)) except re.error: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 7d26954244..386debd7db 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -87,6 +87,7 @@ REQUIREMENTS = [ # with the latest security patches. "cryptography>=3.4.7", "ijson>=3.1", + "matrix-common==1.0.0", ] CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 95f23e27b6..f157132210 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -14,9 +14,8 @@ import json import logging -import re import typing -from typing import Any, Callable, Dict, Generator, Optional, Pattern +from typing import Any, Callable, Dict, Generator, Optional import attr from frozendict import frozendict @@ -35,9 +34,6 @@ if typing.TYPE_CHECKING: logger = logging.getLogger(__name__) -_WILDCARD_RUN = re.compile(r"([\?\*]+)") - - def _reject_invalid_json(val: Any) -> None: """Do not allow Infinity, -Infinity, or NaN values in JSON.""" raise ValueError("Invalid JSON value: '%s'" % val) @@ -185,56 +181,3 @@ def log_failure( if not consumeErrors: return failure return None - - -def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern: - """Converts a glob to a compiled regex object. - - Args: - glob: pattern to match - word_boundary: If True, the pattern will be allowed to match at word boundaries - anywhere in the string. Otherwise, the pattern is anchored at the start and - end of the string. - - Returns: - compiled regex pattern - """ - - # Patterns with wildcards must be simplified to avoid performance cliffs - # - The glob `?**?**?` is equivalent to the glob `???*` - # - The glob `???*` is equivalent to the regex `.{3,}` - chunks = [] - for chunk in _WILDCARD_RUN.split(glob): - # No wildcards? re.escape() - if not _WILDCARD_RUN.match(chunk): - chunks.append(re.escape(chunk)) - continue - - # Wildcards? Simplify. - qmarks = chunk.count("?") - if "*" in chunk: - chunks.append(".{%d,}" % qmarks) - else: - chunks.append(".{%d}" % qmarks) - - res = "".join(chunks) - - if word_boundary: - res = re_word_boundary(res) - else: - # \A anchors at start of string, \Z at end of string - res = r"\A" + res + r"\Z" - - return re.compile(res, re.IGNORECASE) - - -def re_word_boundary(r: str) -> str: - """ - Adds word boundary characters to the start and end of an - expression to require that the match occur as a whole word, - but do so respecting the fact that strings starting or ending - with non-word characters will change word boundaries. - """ - # we can't use \b as it chokes on unicode. however \W seems to be okay - # as shorthand for [^0-9A-Za-z_]. - return r"(^|\W)%s(\W|$)" % (r,) diff --git a/tests/util/test_glob_to_regex.py b/tests/util/test_glob_to_regex.py deleted file mode 100644 index 220accb92b..0000000000 --- a/tests/util/test_glob_to_regex.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from synapse.util import glob_to_regex - -from tests.unittest import TestCase - - -class GlobToRegexTestCase(TestCase): - def test_literal_match(self): - """patterns without wildcards should match""" - pat = glob_to_regex("foobaz") - self.assertTrue( - pat.match("FoobaZ"), "patterns should match and be case-insensitive" - ) - self.assertFalse( - pat.match("x foobaz"), "pattern should not match at word boundaries" - ) - - def test_wildcard_match(self): - pat = glob_to_regex("f?o*baz") - - self.assertTrue( - pat.match("FoobarbaZ"), - "* should match string and pattern should be case-insensitive", - ) - self.assertTrue(pat.match("foobaz"), "* should match 0 characters") - self.assertFalse(pat.match("fooxaz"), "the character after * must match") - self.assertFalse(pat.match("fobbaz"), "? should not match 0 characters") - self.assertFalse(pat.match("fiiobaz"), "? should not match 2 characters") - - def test_multi_wildcard(self): - """patterns with multiple wildcards in a row should match""" - pat = glob_to_regex("**baz") - self.assertTrue(pat.match("agsgsbaz"), "** should match any string") - self.assertTrue(pat.match("baz"), "** should match the empty string") - self.assertEqual(pat.pattern, r"\A.{0,}baz\Z") - - pat = glob_to_regex("*?baz") - self.assertTrue(pat.match("agsgsbaz"), "*? should match any string") - self.assertTrue(pat.match("abaz"), "*? should match a single char") - self.assertFalse(pat.match("baz"), "*? should not match the empty string") - self.assertEqual(pat.pattern, r"\A.{1,}baz\Z") - - pat = glob_to_regex("a?*?*?baz") - self.assertTrue(pat.match("a g baz"), "?*?*? should match 3 chars") - self.assertFalse(pat.match("a..baz"), "?*?*? should not match 2 chars") - self.assertTrue(pat.match("a.gg.baz"), "?*?*? should match 4 chars") - self.assertEqual(pat.pattern, r"\Aa.{3,}baz\Z") -- cgit 1.5.1 From 494ebd7347ba52d702802fba4c3bb13e7bfbc2cf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 6 Dec 2021 10:51:15 -0500 Subject: Include bundled aggregations in /sync and related fixes (#11478) Due to updates to MSC2675 this includes a few fixes: * Include bundled aggregations for /sync. * Do not include bundled aggregations for /initialSync and /events. * Do not bundle aggregations for state events. * Clarifies comments and variable names. --- changelog.d/11478.bugfix | 1 + synapse/events/utils.py | 58 ++++++++++------ synapse/handlers/events.py | 5 +- synapse/handlers/initial_sync.py | 30 ++++++-- synapse/handlers/message.py | 8 +-- synapse/rest/admin/rooms.py | 13 +--- synapse/rest/client/relations.py | 9 ++- synapse/rest/client/room.py | 5 +- synapse/rest/client/sync.py | 6 +- tests/rest/client/test_relations.py | 135 +++++++++++++++++++++++++----------- 10 files changed, 169 insertions(+), 101 deletions(-) create mode 100644 changelog.d/11478.bugfix (limited to 'tests') diff --git a/changelog.d/11478.bugfix b/changelog.d/11478.bugfix new file mode 100644 index 0000000000..5f02636f50 --- /dev/null +++ b/changelog.d/11478.bugfix @@ -0,0 +1 @@ +Include bundled relation aggregations during a limited `/sync` request, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 05219a9dd0..84ef69df67 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -306,6 +306,7 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: def serialize_event( e: Union[JsonDict, EventBase], time_now_ms: int, + *, as_client_event: bool = True, event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, token_id: Optional[str] = None, @@ -393,7 +394,8 @@ class EventClientSerializer: self, event: Union[JsonDict, EventBase], time_now: int, - bundle_relations: bool = True, + *, + bundle_aggregations: bool = True, **kwargs: Any, ) -> JsonDict: """Serializes a single event. @@ -401,8 +403,9 @@ class EventClientSerializer: Args: event: The event being serialized. time_now: The current time in milliseconds - bundle_relations: Whether to include the bundled relations for this - event. + bundle_aggregations: Whether to include the bundled aggregations for this + event. Only applies to non-state events. (State events never include + bundled aggregations.) **kwargs: Arguments to pass to `serialize_event` Returns: @@ -414,20 +417,27 @@ class EventClientSerializer: serialized_event = serialize_event(event, time_now, **kwargs) - # If MSC1849 is enabled then we need to look if there are any relations - # we need to bundle in with the event. - # Do not bundle relations if the event has been redacted - if not event.internal_metadata.is_redacted() and ( - self._msc1849_enabled and bundle_relations + # Check if there are any bundled aggregations to include with the event. + # + # Do not bundle aggregations if any of the following at true: + # + # * Support is disabled via the configuration or the caller. + # * The event is a state event. + # * The event has been redacted. + if ( + self._msc1849_enabled + and bundle_aggregations + and not event.is_state() + and not event.internal_metadata.is_redacted() ): - await self._injected_bundled_relations(event, time_now, serialized_event) + await self._injected_bundled_aggregations(event, time_now, serialized_event) return serialized_event - async def _injected_bundled_relations( + async def _injected_bundled_aggregations( self, event: EventBase, time_now: int, serialized_event: JsonDict ) -> None: - """Potentially injects bundled relations into the unsigned portion of the serialized event. + """Potentially injects bundled aggregations into the unsigned portion of the serialized event. Args: event: The event being serialized. @@ -435,7 +445,7 @@ class EventClientSerializer: serialized_event: The serialized event which may be modified. """ - # Do not bundle relations for an event which represents an edit or an + # Do not bundle aggregations for an event which represents an edit or an # annotation. It does not make sense for them to have related events. relates_to = event.content.get("m.relates_to") if isinstance(relates_to, (dict, frozendict)): @@ -445,18 +455,18 @@ class EventClientSerializer: event_id = event.event_id - # The bundled relations to include. - relations = {} + # The bundled aggregations to include. + aggregations = {} annotations = await self.store.get_aggregation_groups_for_event(event_id) if annotations.chunk: - relations[RelationTypes.ANNOTATION] = annotations.to_dict() + aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() references = await self.store.get_relations_for_event( event_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - relations[RelationTypes.REFERENCE] = references.to_dict() + aggregations[RelationTypes.REFERENCE] = references.to_dict() edit = None if event.type == EventTypes.Message: @@ -482,7 +492,7 @@ class EventClientSerializer: else: serialized_event["content"].pop("m.relates_to", None) - relations[RelationTypes.REPLACE] = { + aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, "sender": edit.sender, @@ -495,17 +505,19 @@ class EventClientSerializer: latest_thread_event, ) = await self.store.get_thread_summary(event_id) if latest_thread_event: - relations[RelationTypes.THREAD] = { - # Don't bundle relations as this could recurse forever. + aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_relations=False + latest_thread_event, time_now, bundle_aggregations=False ), "count": thread_count, } - # If any bundled relations were found, include them. - if relations: - serialized_event["unsigned"].setdefault("m.relations", {}).update(relations) + # If any bundled aggregations were found, include them. + if aggregations: + serialized_event["unsigned"].setdefault("m.relations", {}).update( + aggregations + ) async def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index b4ff935546..32b0254c5f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -122,9 +122,8 @@ class EventStreamHandler: events, time_now, as_client_event=as_client_event, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. - bundle_relations=False, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, ) chunk = { diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d4e4556155..9cd21e7f2b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -165,7 +165,11 @@ class InitialSyncHandler: invite_event = await self.store.get_event(event.event_id) d["invite"] = await self._event_serializer.serialize_event( - invite_event, time_now, as_client_event + invite_event, + time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, + as_client_event=as_client_event, ) rooms_ret.append(d) @@ -216,7 +220,11 @@ class InitialSyncHandler: d["messages"] = { "chunk": ( await self._event_serializer.serialize_events( - messages, time_now=time_now, as_client_event=as_client_event + messages, + time_now=time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, + as_client_event=as_client_event, ) ), "start": await start_token.to_string(self.store), @@ -226,6 +234,8 @@ class InitialSyncHandler: d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, + # Don't bundle aggregations as this is a deprecated API. + bundle_aggregations=False, as_client_event=as_client_event, ) @@ -366,14 +376,18 @@ class InitialSyncHandler: "room_id": room_id, "messages": { "chunk": ( - await self._event_serializer.serialize_events(messages, time_now) + # Don't bundle aggregations as this is a deprecated API. + await self._event_serializer.serialize_events( + messages, time_now, bundle_aggregations=False + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( + # Don't bundle aggregations as this is a deprecated API. await self._event_serializer.serialize_events( - room_state.values(), time_now + room_state.values(), time_now, bundle_aggregations=False ) ), "presence": [], @@ -392,8 +406,9 @@ class InitialSyncHandler: # TODO: These concurrently time_now = self.clock.time_msec() + # Don't bundle aggregations as this is a deprecated API. state = await self._event_serializer.serialize_events( - current_state.values(), time_now + current_state.values(), time_now, bundle_aggregations=False ) now_token = self.hs.get_event_sources().get_current_token() @@ -467,7 +482,10 @@ class InitialSyncHandler: "room_id": room_id, "messages": { "chunk": ( - await self._event_serializer.serialize_events(messages, time_now) + # Don't bundle aggregations as this is a deprecated API. + await self._event_serializer.serialize_events( + messages, time_now, bundle_aggregations=False + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 95b4fad3c6..87f671708c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -247,13 +247,7 @@ class MessageHandler: room_state = room_state_events[membership_event_id] now = self.clock.time_msec() - events = await self._event_serializer.serialize_events( - room_state.values(), - now, - # We don't bother bundling aggregations in when asked for state - # events, as clients won't use them. - bundle_relations=False, - ) + events = await self._event_serializer.serialize_events(room_state.values(), now) return events async def get_joined_members(self, requester: Requester, room_id: str) -> dict: diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 6bbc5510f0..669ab44a45 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -449,13 +449,7 @@ class RoomStateRestServlet(RestServlet): event_ids = await self.store.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() - room_state = await self._event_serializer.serialize_events( - events.values(), - now, - # We don't bother bundling aggregations in when asked for state - # events, as clients won't use them. - bundle_relations=False, - ) + room_state = await self._event_serializer.serialize_events(events.values(), now) ret = {"state": room_state} return HTTPStatus.OK, ret @@ -789,10 +783,7 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_relations=False, + results["state"], time_now ) return HTTPStatus.OK, results diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index b1a3304849..fc4e6921c5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -224,14 +224,13 @@ class RelationPaginationServlet(RestServlet): ) now = self.clock.time_msec() - # We set bundle_relations to False when retrieving the original - # event because we want the content before relations were applied to - # it. + # Do not bundle aggregations when retrieving the original event because + # we want the content before relations are applied to it. original_event = await self._event_serializer.serialize_event( - event, now, bundle_relations=False + event, now, bundle_aggregations=False ) # The relations returned for the requested event do include their - # bundled relations. + # bundled aggregations. serialized_events = await self._event_serializer.serialize_events(events, now) return_value = pagination_chunk.to_dict() diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 3598967be0..f48e2e6ca2 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -716,10 +716,7 @@ class RoomEventContextServlet(RestServlet): results["events_after"], time_now ) results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_relations=False, + results["state"], time_now ) return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index b6a2485732..88e4f5e063 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -520,9 +520,9 @@ class SyncRestServlet(RestServlet): return self._event_serializer.serialize_events( events, time_now=time_now, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. - bundle_relations=False, + # Don't bother to bundle aggregations if the timeline is unlimited, + # as clients will have all the necessary information. + bundle_aggregations=room.timeline.limited, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index b494da5138..397c12c2a6 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin -from synapse.rest.client import login, register, relations, room +from synapse.rest.client import login, register, relations, room, sync from tests import unittest from tests.server import FakeChannel @@ -29,6 +29,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): servlets = [ relations.register_servlets, room.register_servlets, + sync.register_servlets, login.register_servlets, register.register_servlets, admin.register_servlets_for_client_rest_resource, @@ -454,11 +455,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(400, channel.code, channel.json_body) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) - def test_aggregation_get_event(self): - """Test that annotations, references, and threads get correctly bundled when - getting the parent event. - """ - + def test_bundled_aggregations(self): + """Test that annotations, references, and threads get correctly bundled.""" + # Setup by sending a variety of relations. channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -485,49 +484,107 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) + def assert_bundle(actual): + """Assert the expected values of the bundled aggregations.""" - self.assertEquals( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { + # Ensure the fields are as expected. + self.assertCountEqual( + actual.keys(), + ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.THREAD, + ), + ) + + # Check the values of each field. + self.assertEquals( + { "chunk": [ {"type": "m.reaction", "key": "a", "count": 2}, {"type": "m.reaction", "key": "b", "count": 1}, ] }, - RelationTypes.REFERENCE: { - "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] - }, - RelationTypes.THREAD: { - "count": 2, - "latest_event": { - "age": 100, - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "origin_server_ts": 1600, - "room_id": self.room, - "sender": self.user_id, - "type": "m.room.test", - "unsigned": {"age": 100}, - "user_id": self.user_id, + actual[RelationTypes.ANNOTATION], + ) + + self.assertEquals( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + actual[RelationTypes.REFERENCE], + ) + + self.assertEquals( + 2, + actual[RelationTypes.THREAD].get("count"), + ) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } }, + "event_id": thread_2, + "room_id": self.room, + "sender": self.user_id, + "type": "m.room.test", + "user_id": self.user_id, }, - }, + actual[RelationTypes.THREAD].get("latest_event"), + ) + + def _find_and_assert_event(events): + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + break + else: + raise AssertionError(f"Event {self.parent_id} not found in chunk") + assert_bundle(event["unsigned"].get("m.relations")) + + # Request the event directly. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["unsigned"].get("m.relations")) + + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, ) + self.assertEquals(200, channel.code, channel.json_body) + _find_and_assert_event(channel.json_body["chunk"]) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) + + # Request sync. + channel = self.make_request("GET", "/sync", access_token=self.user_token) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + _find_and_assert_event(room_timeline["events"]) + + # Note that /relations is tested separately in test_aggregation_get_event_for_thread + # since it needs different data configured. def test_aggregation_get_event_for_annotation(self): - """Test that annotations do not get bundled relations included + """Test that annotations do not get bundled aggregations included when directly requested. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -549,7 +606,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) def test_aggregation_get_event_for_thread(self): - """Test that threads get bundled relations included when directly requested.""" + """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") self.assertEquals(200, channel.code, channel.json_body) thread_id = channel.json_body["event_id"] -- cgit 1.5.1 From 8b4b153c9e86c04c7db8c74fde4b6a04becbc461 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 6 Dec 2021 17:59:50 +0100 Subject: Add admin API to get some information about federation status (#11407) --- changelog.d/11407.feature | 1 + docs/SUMMARY.md | 1 + docs/usage/administration/admin_api/federation.md | 114 ++++++ synapse/rest/admin/__init__.py | 6 + synapse/rest/admin/federation.py | 135 +++++++ synapse/storage/databases/main/transactions.py | 70 ++++ tests/rest/admin/test_federation.py | 456 ++++++++++++++++++++++ 7 files changed, 783 insertions(+) create mode 100644 changelog.d/11407.feature create mode 100644 docs/usage/administration/admin_api/federation.md create mode 100644 synapse/rest/admin/federation.py create mode 100644 tests/rest/admin/test_federation.py (limited to 'tests') diff --git a/changelog.d/11407.feature b/changelog.d/11407.feature new file mode 100644 index 0000000000..1d21bde98f --- /dev/null +++ b/changelog.d/11407.feature @@ -0,0 +1 @@ +Add admin API to get some information about federation status with remote servers. \ No newline at end of file diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 41c8f0fbc9..b05af6d690 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -65,6 +65,7 @@ - [Statistics](admin_api/statistics.md) - [Users](admin_api/user_admin_api.md) - [Server Version](admin_api/version_api.md) + - [Federation](usage/administration/admin_api/federation.md) - [Manhole](manhole.md) - [Monitoring](metrics-howto.md) - [Understanding Synapse Through Grafana Graphs](usage/administration/understanding_synapse_through_grafana_graphs.md) diff --git a/docs/usage/administration/admin_api/federation.md b/docs/usage/administration/admin_api/federation.md new file mode 100644 index 0000000000..8f9535f57b --- /dev/null +++ b/docs/usage/administration/admin_api/federation.md @@ -0,0 +1,114 @@ +# Federation API + +This API allows a server administrator to manage Synapse's federation with other homeservers. + +Note: This API is new, experimental and "subject to change". + +## List of destinations + +This API gets the current destination retry timing info for all remote servers. + +The list contains all the servers with which the server federates, +regardless of whether an error occurred or not. +If an error occurs, it may take up to 20 minutes for the error to be displayed here, +as a complete retry must have failed. + +The API is: + +A standard request with no filtering: + +``` +GET /_synapse/admin/v1/federation/destinations +``` + +A response body like the following is returned: + +```json +{ + "destinations":[ + { + "destination": "matrix.org", + "retry_last_ts": 1557332397936, + "retry_interval": 3000000, + "failure_ts": 1557329397936, + "last_successful_stream_ordering": null + } + ], + "total": 1 +} +``` + +To paginate, check for `next_token` and if present, call the endpoint again +with `from` set to the value of `next_token`. This will return a new page. + +If the endpoint does not return a `next_token` then there are no more destinations +to paginate through. + +**Parameters** + +The following query parameters are available: + +- `from` - Offset in the returned list. Defaults to `0`. +- `limit` - Maximum amount of destinations to return. Defaults to `100`. +- `order_by` - The method in which to sort the returned list of destinations. + Valid values are: + - `destination` - Destinations are ordered alphabetically by remote server name. + This is the default. + - `retry_last_ts` - Destinations are ordered by time of last retry attempt in ms. + - `retry_interval` - Destinations are ordered by how long until next retry in ms. + - `failure_ts` - Destinations are ordered by when the server started failing in ms. + - `last_successful_stream_ordering` - Destinations are ordered by the stream ordering + of the most recent successfully-sent PDU. +- `dir` - Direction of room order. Either `f` for forwards or `b` for backwards. Setting + this value to `b` will reverse the above sort order. Defaults to `f`. + +*Caution:* The database only has an index on the column `destination`. +This means that if a different sort order is used, +this can cause a large load on the database, especially for large environments. + +**Response** + +The following fields are returned in the JSON response body: + +- `destinations` - An array of objects, each containing information about a destination. + Destination objects contain the following fields: + - `destination` - string - Name of the remote server to federate. + - `retry_last_ts` - integer - The last time Synapse tried and failed to reach the + remote server, in ms. This is `0` if the last attempt to communicate with the + remote server was successful. + - `retry_interval` - integer - How long since the last time Synapse tried to reach + the remote server before trying again, in ms. This is `0` if no further retrying occuring. + - `failure_ts` - nullable integer - The first time Synapse tried and failed to reach the + remote server, in ms. This is `null` if communication with the remote server has never failed. + - `last_successful_stream_ordering` - nullable integer - The stream ordering of the most + recent successfully-sent [PDU](understanding_synapse_through_grafana_graphs.md#federation) + to this destination, or `null` if this information has not been tracked yet. +- `next_token`: string representing a positive integer - Indication for pagination. See above. +- `total` - integer - Total number of destinations. + +# Destination Details API + +This API gets the retry timing info for a specific remote server. + +The API is: + +``` +GET /_synapse/admin/v1/federation/destinations/ +``` + +A response body like the following is returned: + +```json +{ + "destination": "matrix.org", + "retry_last_ts": 1557332397936, + "retry_interval": 3000000, + "failure_ts": 1557329397936, + "last_successful_stream_ordering": null +} +``` + +**Response** + +The response fields are the same like in the `destinations` array in +[List of destinations](#list-of-destinations) response. diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c51a029bf3..c499afd4be 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -40,6 +40,10 @@ from synapse.rest.admin.event_reports import ( EventReportDetailRestServlet, EventReportsRestServlet, ) +from synapse.rest.admin.federation import ( + DestinationsRestServlet, + ListDestinationsRestServlet, +) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.registration_tokens import ( @@ -261,6 +265,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ListRegistrationTokensRestServlet(hs).register(http_server) NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) + DestinationsRestServlet(hs).register(http_server) + ListDestinationsRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. if hs.config.worker.worker_app is None: diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py new file mode 100644 index 0000000000..744687be35 --- /dev/null +++ b/synapse/rest/admin/federation.py @@ -0,0 +1,135 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.storage.databases.main.transactions import DestinationSortOrder +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ListDestinationsRestServlet(RestServlet): + """Get request to list all destinations. + This needs user to have administrator access in Synapse. + + GET /_synapse/admin/v1/federation/destinations?from=0&limit=10 + + returns: + 200 OK with list of destinations if success otherwise an error. + + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `destination` can be used to filter by destination. + The parameter `order_by` can be used to order the result. + """ + + PATTERNS = admin_patterns("/federation/destinations$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + destination = parse_string(request, "destination") + + order_by = parse_string( + request, + "order_by", + default=DestinationSortOrder.DESTINATION.value, + allowed_values=[dest.value for dest in DestinationSortOrder], + ) + + direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) + + destinations, total = await self._store.get_destinations_paginate( + start, limit, destination, order_by, direction + ) + response = {"destinations": destinations, "total": total} + if (start + limit) < total: + response["next_token"] = str(start + len(destinations)) + + return HTTPStatus.OK, response + + +class DestinationsRestServlet(RestServlet): + """Get details of a destination. + This needs user to have administrator access in Synapse. + + GET /_synapse/admin/v1/federation/destinations/ + + returns: + 200 OK with details of a destination if success otherwise an error. + """ + + PATTERNS = admin_patterns("/federation/destinations/(?P[^/]+)$") + + def __init__(self, hs: "HomeServer"): + self._auth = hs.get_auth() + self._store = hs.get_datastore() + + async def on_GET( + self, request: SynapseRequest, destination: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + destination_retry_timings = await self._store.get_destination_retry_timings( + destination + ) + + if not destination_retry_timings: + raise NotFoundError("Unknown destination") + + last_successful_stream_ordering = ( + await self._store.get_destination_last_successful_stream_ordering( + destination + ) + ) + + response = { + "destination": destination, + "failure_ts": destination_retry_timings.failure_ts, + "retry_last_ts": destination_retry_timings.retry_last_ts, + "retry_interval": destination_retry_timings.retry_interval, + "last_successful_stream_ordering": last_successful_stream_ordering, + } + + return HTTPStatus.OK, response diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index d7dc1f73ac..1622822552 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -14,6 +14,7 @@ import logging from collections import namedtuple +from enum import Enum from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple import attr @@ -44,6 +45,16 @@ _UpdateTransactionRow = namedtuple( ) +class DestinationSortOrder(Enum): + """Enum to define the sorting method used when returning destinations.""" + + DESTINATION = "destination" + RETRY_LAST_TS = "retry_last_ts" + RETTRY_INTERVAL = "retry_interval" + FAILURE_TS = "failure_ts" + LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class DestinationRetryTimings: """The current destination retry timing info for a remote server.""" @@ -480,3 +491,62 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): destinations = [row[0] for row in txn] return destinations + + async def get_destinations_paginate( + self, + start: int, + limit: int, + destination: Optional[str] = None, + order_by: str = DestinationSortOrder.DESTINATION.value, + direction: str = "f", + ) -> Tuple[List[JsonDict], int]: + """Function to retrieve a paginated list of destinations. + This will return a json list of destinations and the + total number of destinations matching the filter criteria. + + Args: + start: start number to begin the query from + limit: number of rows to retrieve + destination: search string in destination + order_by: the sort order of the returned list + direction: sort ascending or descending + Returns: + A tuple of a list of mappings from destination to information + and a count of total destinations. + """ + + def get_destinations_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: + order_by_column = DestinationSortOrder(order_by).value + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + args = [] + where_statement = "" + if destination: + args.extend(["%" + destination.lower() + "%"]) + where_statement = "WHERE LOWER(destination) LIKE ?" + + sql_base = f"FROM destinations {where_statement} " + sql = f"SELECT COUNT(*) as total_destinations {sql_base}" + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = f""" + SELECT destination, retry_last_ts, retry_interval, failure_ts, + last_successful_stream_ordering + {sql_base} + ORDER BY {order_by_column} {order}, destination ASC + LIMIT ? OFFSET ? + """ + txn.execute(sql, args + [limit, start]) + destinations = self.db_pool.cursor_to_dict(txn) + return destinations, count + + return await self.db_pool.runInteraction( + "get_destinations_paginate_txn", get_destinations_paginate_txn + ) diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py new file mode 100644 index 0000000000..5188499ef2 --- /dev/null +++ b/tests/rest/admin/test_federation.py @@ -0,0 +1,456 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from http import HTTPStatus +from typing import List, Optional + +from parameterized import parameterized + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.types import JsonDict + +from tests import unittest + + +class FederationTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): + self.store = hs.get_datastore() + self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.url = "/_synapse/admin/v1/federation/destinations" + + @parameterized.expand( + [ + ("/_synapse/admin/v1/federation/destinations",), + ("/_synapse/admin/v1/federation/destinations/dummy",), + ] + ) + def test_requester_is_no_admin(self, url: str): + """ + If the user is not a server admin, an error 403 is returned. + """ + + self.register_user("user", "pass", admin=False) + other_user_tok = self.login("user", "pass") + + channel = self.make_request( + "GET", + url, + content={}, + access_token=other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + + # negative limit + channel = self.make_request( + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from + channel = self.make_request( + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # unkown order_by + channel = self.make_request( + "GET", + self.url + "?order_by=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid search order + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid destination + channel = self.make_request( + "GET", + self.url + "/dummy", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_limit(self): + """ + Testing list of destinations with limit + """ + + number_destinations = 20 + self._create_destinations(number_destinations) + + channel = self.make_request( + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), 5) + self.assertEqual(channel.json_body["next_token"], "5") + self._check_fields(channel.json_body["destinations"]) + + def test_from(self): + """ + Testing list of destinations with a defined starting point (from) + """ + + number_destinations = 20 + self._create_destinations(number_destinations) + + channel = self.make_request( + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["destinations"]) + + def test_limit_and_from(self): + """ + Testing list of destinations with a defined starting point and limit + """ + + number_destinations = 20 + self._create_destinations(number_destinations) + + channel = self.make_request( + "GET", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(channel.json_body["next_token"], "15") + self.assertEqual(len(channel.json_body["destinations"]), 10) + self._check_fields(channel.json_body["destinations"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_destinations = 20 + self._create_destinations(number_destinations) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=20", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), number_destinations) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=21", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), number_destinations) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=19", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), 19) + self.assertEqual(channel.json_body["next_token"], "19") + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", + self.url + "?from=19", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], number_destinations) + self.assertEqual(len(channel.json_body["destinations"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def test_list_all_destinations(self): + """ + List all destinations. + """ + number_destinations = 5 + self._create_destinations(number_destinations) + + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(number_destinations, len(channel.json_body["destinations"])) + self.assertEqual(number_destinations, channel.json_body["total"]) + + # Check that all fields are available + self._check_fields(channel.json_body["destinations"]) + + def test_order_by(self): + """ + Testing order list with parameter `order_by` + """ + + def _order_test( + expected_destination_list: List[str], + order_by: Optional[str], + dir: Optional[str] = None, + ): + """Request the list of destinations in a certain order. + Assert that order is what we expect + + Args: + expected_destination_list: The list of user_id in the order + we expect to get back from the server + order_by: The type of ordering to give the server + dir: The direction of ordering to give the server + """ + + url = f"{self.url}?" + if order_by is not None: + url += f"order_by={order_by}&" + if dir is not None and dir in ("b", "f"): + url += f"dir={dir}" + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], len(expected_destination_list)) + + returned_order = [ + row["destination"] for row in channel.json_body["destinations"] + ] + self.assertEqual(expected_destination_list, returned_order) + self._check_fields(channel.json_body["destinations"]) + + # create destinations + dest = [ + ("sub-a.example.com", 100, 300, 200, 300), + ("sub-b.example.com", 200, 200, 100, 100), + ("sub-c.example.com", 300, 100, 300, 200), + ] + for ( + destination, + failure_ts, + retry_last_ts, + retry_interval, + last_successful_stream_ordering, + ) in dest: + self.get_success( + self.store.set_destination_retry_timings( + destination, failure_ts, retry_last_ts, retry_interval + ) + ) + self.get_success( + self.store.set_destination_last_successful_stream_ordering( + destination, last_successful_stream_ordering + ) + ) + + # order by default (destination) + _order_test([dest[0][0], dest[1][0], dest[2][0]], None) + _order_test([dest[0][0], dest[1][0], dest[2][0]], None, "f") + _order_test([dest[2][0], dest[1][0], dest[0][0]], None, "b") + + # order by destination + _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination") + _order_test([dest[0][0], dest[1][0], dest[2][0]], "destination", "f") + _order_test([dest[2][0], dest[1][0], dest[0][0]], "destination", "b") + + # order by failure_ts + _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts") + _order_test([dest[0][0], dest[1][0], dest[2][0]], "failure_ts", "f") + _order_test([dest[2][0], dest[1][0], dest[0][0]], "failure_ts", "b") + + # order by retry_last_ts + _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts") + _order_test([dest[2][0], dest[1][0], dest[0][0]], "retry_last_ts", "f") + _order_test([dest[0][0], dest[1][0], dest[2][0]], "retry_last_ts", "b") + + # order by retry_interval + _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval") + _order_test([dest[1][0], dest[0][0], dest[2][0]], "retry_interval", "f") + _order_test([dest[2][0], dest[0][0], dest[1][0]], "retry_interval", "b") + + # order by last_successful_stream_ordering + _order_test( + [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering" + ) + _order_test( + [dest[1][0], dest[2][0], dest[0][0]], "last_successful_stream_ordering", "f" + ) + _order_test( + [dest[0][0], dest[2][0], dest[1][0]], "last_successful_stream_ordering", "b" + ) + + def test_search_term(self): + """Test that searching for a destination works correctly""" + + def _search_test( + expected_destination: Optional[str], + search_term: str, + ): + """Search for a destination and check that the returned destinationis a match + + Args: + expected_destination: The room_id expected to be returned by the API. + Set to None to expect zero results for the search + search_term: The term to search for room names with + """ + url = f"{self.url}?destination={search_term}" + channel = self.make_request( + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # Check that destinations were returned + self.assertTrue("destinations" in channel.json_body) + self._check_fields(channel.json_body["destinations"]) + destinations = channel.json_body["destinations"] + + # Check that the expected number of destinations were returned + expected_destination_count = 1 if expected_destination else 0 + self.assertEqual(len(destinations), expected_destination_count) + self.assertEqual(channel.json_body["total"], expected_destination_count) + + if expected_destination: + # Check that the first returned destination is correct + self.assertEqual(expected_destination, destinations[0]["destination"]) + + number_destinations = 3 + self._create_destinations(number_destinations) + + # Test searching + _search_test("sub0.example.com", "0") + _search_test("sub0.example.com", "sub0") + + _search_test("sub1.example.com", "1") + _search_test("sub1.example.com", "1.") + + # Test case insensitive + _search_test("sub0.example.com", "SUB0") + + _search_test(None, "foo") + _search_test(None, "bar") + + def test_get_single_destination(self): + """ + Get one specific destinations. + """ + self._create_destinations(5) + + channel = self.make_request( + "GET", + self.url + "/sub0.example.com", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual("sub0.example.com", channel.json_body["destination"]) + + # Check that all fields are available + # convert channel.json_body into a List + self._check_fields([channel.json_body]) + + def _create_destinations(self, number_destinations: int): + """Create a number of destinations + + Args: + number_destinations: Number of destinations to be created + """ + for i in range(0, number_destinations): + dest = f"sub{i}.example.com" + self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50)) + self.get_success( + self.store.set_destination_last_successful_stream_ordering(dest, 100) + ) + + def _check_fields(self, content: List[JsonDict]): + """Checks that the expected destination attributes are present in content + + Args: + content: List that is checked for content + """ + for c in content: + self.assertIn("destination", c) + self.assertIn("retry_last_ts", c) + self.assertIn("retry_interval", c) + self.assertIn("failure_ts", c) + self.assertIn("last_successful_stream_ordering", c) -- cgit 1.5.1 From a15a893df8428395df7cb95b729431575001c38a Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 6 Dec 2021 18:43:06 +0100 Subject: Save the OIDC session ID (sid) with the device on login (#11482) As a step towards allowing back-channel logout for OIDC. --- changelog.d/11482.misc | 1 + synapse/handlers/auth.py | 34 +++++- synapse/handlers/device.py | 8 ++ synapse/handlers/oidc.py | 58 +++++---- synapse/handlers/register.py | 15 ++- synapse/handlers/sso.py | 4 + synapse/module_api/__init__.py | 2 + synapse/replication/http/login.py | 8 ++ synapse/rest/client/login.py | 7 +- synapse/storage/databases/main/devices.py | 50 +++++++- .../delta/65/11_devices_auth_provider_session.sql | 27 +++++ tests/handlers/test_auth.py | 6 +- tests/handlers/test_cas.py | 40 +++++- tests/handlers/test_oidc.py | 135 ++++++++++++++++++--- tests/handlers/test_saml.py | 40 +++++- 15 files changed, 370 insertions(+), 65 deletions(-) create mode 100644 changelog.d/11482.misc create mode 100644 synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql (limited to 'tests') diff --git a/changelog.d/11482.misc b/changelog.d/11482.misc new file mode 100644 index 0000000000..e78662988f --- /dev/null +++ b/changelog.d/11482.misc @@ -0,0 +1 @@ +Save the OpenID Connect session ID on login. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 4d9c4e5834..61607cf2ba 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -39,6 +39,7 @@ import attr import bcrypt import pymacaroons import unpaddedbase64 +from pymacaroons.exceptions import MacaroonVerificationFailedException from twisted.web.server import Request @@ -182,8 +183,11 @@ class LoginTokenAttributes: user_id = attr.ib(type=str) - # the SSO Identity Provider that the user authenticated with, to get this token auth_provider_id = attr.ib(type=str) + """The SSO Identity Provider that the user authenticated with, to get this token.""" + + auth_provider_session_id = attr.ib(type=Optional[str]) + """The session ID advertised by the SSO Identity Provider.""" class AuthHandler: @@ -1650,6 +1654,7 @@ class AuthHandler: client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, new_user: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> None: """Having figured out a mxid for this user, complete the HTTP request @@ -1665,6 +1670,7 @@ class AuthHandler: during successful login. Must be JSON serializable. new_user: True if we should use wording appropriate to a user who has just registered. + auth_provider_session_id: The session ID from the SSO IdP received during login. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1685,6 +1691,7 @@ class AuthHandler: extra_attributes, new_user=new_user, user_profile_data=profile, + auth_provider_session_id=auth_provider_session_id, ) def _complete_sso_login( @@ -1696,6 +1703,7 @@ class AuthHandler: extra_attributes: Optional[JsonDict] = None, new_user: bool = False, user_profile_data: Optional[ProfileInfo] = None, + auth_provider_session_id: Optional[str] = None, ) -> None: """ The synchronous portion of complete_sso_login. @@ -1717,7 +1725,9 @@ class AuthHandler: # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( - registered_user_id, auth_provider_id=auth_provider_id + registered_user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) # Append the login token to the original redirect URL (i.e. with its query @@ -1822,6 +1832,7 @@ class MacaroonGenerator: self, user_id: str, auth_provider_id: str, + auth_provider_session_id: Optional[str] = None, duration_in_ms: int = (2 * 60 * 1000), ) -> str: macaroon = self._generate_base_macaroon(user_id) @@ -1830,6 +1841,10 @@ class MacaroonGenerator: expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,)) + if auth_provider_session_id is not None: + macaroon.add_first_party_caveat( + "auth_provider_session_id = %s" % (auth_provider_session_id,) + ) return macaroon.serialize() def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: @@ -1851,15 +1866,28 @@ class MacaroonGenerator: user_id = get_value_from_macaroon(macaroon, "user_id") auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") + auth_provider_session_id: Optional[str] = None + try: + auth_provider_session_id = get_value_from_macaroon( + macaroon, "auth_provider_session_id" + ) + except MacaroonVerificationFailedException: + pass + v = pymacaroons.Verifier() v.satisfy_exact("gen = 1") v.satisfy_exact("type = login") v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) + v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = ")) satisfy_expiry(v, self.hs.get_clock().time_msec) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id) + return LoginTokenAttributes( + user_id=user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 68b446eb66..82ee11e921 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id: str, device_id: Optional[str], initial_device_display_name: Optional[str] = None, + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> str: """ If the given device has not been registered, register it with the @@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id: @user:id device_id: device id supplied by client initial_device_display_name: device display name from client + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from the SSO IdP. Returns: device id (generated if none was supplied) """ @@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=device_id, initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [device_id]) @@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler): user_id=user_id, device_id=new_device_id, initial_device_display_name=initial_device_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if new_device: await self.notify_device_update(user_id, [new_device_id]) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 3665d91513..deb3539751 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -23,7 +23,7 @@ from authlib.common.security import generate_token from authlib.jose import JsonWebToken, jwt from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri -from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo +from authlib.oidc.core import CodeIDToken, UserInfo from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url from jinja2 import Environment, Template from pymacaroons.exceptions import ( @@ -117,7 +117,8 @@ class OidcHandler: for idp_id, p in self._providers.items(): try: await p.load_metadata() - await p.load_jwks() + if not p._uses_userinfo: + await p.load_jwks() except Exception as e: raise Exception( "Error while initialising OIDC provider %r" % (idp_id,) @@ -498,10 +499,6 @@ class OidcProvider: return await self._jwks.get() async def _load_jwks(self) -> JWKS: - if self._uses_userinfo: - # We're not using jwt signing, return an empty jwk set - return {"keys": []} - metadata = await self.load_metadata() # Load the JWKS using the `jwks_uri` metadata. @@ -663,7 +660,7 @@ class OidcProvider: return UserInfo(resp) - async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: + async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: """Return an instance of UserInfo from token's ``id_token``. Args: @@ -673,7 +670,7 @@ class OidcProvider: request. This value should match the one inside the token. Returns: - An object representing the user. + The decoded claims in the ID token. """ metadata = await self.load_metadata() claims_params = { @@ -684,9 +681,6 @@ class OidcProvider: # If we got an `access_token`, there should be an `at_hash` claim # in the `id_token` that we can check against. claims_params["access_token"] = token["access_token"] - claims_cls = CodeIDToken - else: - claims_cls = ImplicitIDToken alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) jwt = JsonWebToken(alg_values) @@ -703,7 +697,7 @@ class OidcProvider: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=claims_cls, + claims_cls=CodeIDToken, claims_options=claim_options, claims_params=claims_params, ) @@ -713,7 +707,7 @@ class OidcProvider: claims = jwt.decode( id_token, key=jwk_set, - claims_cls=claims_cls, + claims_cls=CodeIDToken, claims_options=claim_options, claims_params=claims_params, ) @@ -721,7 +715,8 @@ class OidcProvider: logger.debug("Decoded id_token JWT %r; validating", claims) claims.validate(leeway=120) # allows 2 min of clock skew - return UserInfo(claims) + + return claims async def handle_redirect_request( self, @@ -837,8 +832,22 @@ class OidcProvider: logger.debug("Successfully obtained OAuth2 token data: %r", token) - # Now that we have a token, get the userinfo, either by decoding the - # `id_token` or by fetching the `userinfo_endpoint`. + # If there is an id_token, it should be validated, regardless of the + # userinfo endpoint is used or not. + if token.get("id_token") is not None: + try: + id_token = await self._parse_id_token(token, nonce=session_data.nonce) + sid = id_token.get("sid") + except Exception as e: + logger.exception("Invalid id_token") + self._sso_handler.render_error(request, "invalid_token", str(e)) + return + else: + id_token = None + sid = None + + # Now that we have a token, get the userinfo either from the `id_token` + # claims or by fetching the `userinfo_endpoint`. if self._uses_userinfo: try: userinfo = await self._fetch_userinfo(token) @@ -846,13 +855,14 @@ class OidcProvider: logger.exception("Could not fetch userinfo") self._sso_handler.render_error(request, "fetch_error", str(e)) return + elif id_token is not None: + userinfo = UserInfo(id_token) else: - try: - userinfo = await self._parse_id_token(token, nonce=session_data.nonce) - except Exception as e: - logger.exception("Invalid id_token") - self._sso_handler.render_error(request, "invalid_token", str(e)) - return + logger.error("Missing id_token in token response") + self._sso_handler.render_error( + request, "invalid_token", "Missing id_token in token response" + ) + return # first check if we're doing a UIA if session_data.ui_auth_session_id: @@ -884,7 +894,7 @@ class OidcProvider: # Call the mapper to register/login the user try: await self._complete_oidc_login( - userinfo, token, request, session_data.client_redirect_url + userinfo, token, request, session_data.client_redirect_url, sid ) except MappingException as e: logger.exception("Could not map user") @@ -896,6 +906,7 @@ class OidcProvider: token: Token, request: SynapseRequest, client_redirect_url: str, + sid: Optional[str], ) -> None: """Given a UserInfo response, complete the login flow @@ -1008,6 +1019,7 @@ class OidcProvider: oidc_response_to_user_attributes, grandfather_existing_users, extra_attributes, + auth_provider_session_id=sid, ) def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index b14ddd8267..f08a516a75 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -746,6 +746,7 @@ class RegistrationHandler: is_appservice_ghost: bool = False, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> Tuple[str, str, Optional[int], Optional[str]]: """Register a device for a user and generate an access token. @@ -756,9 +757,9 @@ class RegistrationHandler: device_id: The device ID to check, or None to generate a new one. initial_display_name: An optional display name for the device. is_guest: Whether this is a guest account - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. should_issue_refresh_token: Whether it should also issue a refresh token + auth_provider_session_id: The session ID received during login from the SSO IdP. Returns: Tuple of device ID, access token, access token expiration time and refresh token """ @@ -769,6 +770,8 @@ class RegistrationHandler: is_guest=is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) login_counter.labels( @@ -791,6 +794,8 @@ class RegistrationHandler: is_guest: bool = False, is_appservice_ghost: bool = False, should_issue_refresh_token: bool = False, + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> LoginDict: """Helper for register_device @@ -822,7 +827,11 @@ class RegistrationHandler: refresh_token_id = None registered_device_id = await self.device_handler.check_device_registered( - user_id, device_id, initial_display_name + user_id, + device_id, + initial_display_name, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) if is_guest: assert access_token_expiry is None diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 49fde01cf0..65c27bc64a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -365,6 +365,7 @@ class SsoHandler: sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], extra_login_attributes: Optional[JsonDict] = None, + auth_provider_session_id: Optional[str] = None, ) -> None: """ Given an SSO ID, retrieve the user ID for it and possibly register the user. @@ -415,6 +416,8 @@ class SsoHandler: extra_login_attributes: An optional dictionary of extra attributes to be provided to the client in the login response. + auth_provider_session_id: An optional session ID from the IdP. + Raises: MappingException if there was a problem mapping the response to a user. RedirectException: if the mapping provider needs to redirect the user @@ -490,6 +493,7 @@ class SsoHandler: client_redirect_url, extra_login_attributes, new_user=new_user, + auth_provider_session_id=auth_provider_session_id, ) async def _call_attribute_mapper( diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index a8154168be..6bfb4b8d1b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -626,6 +626,7 @@ class ModuleApi: user_id: str, duration_in_ms: int = (2 * 60 * 1000), auth_provider_id: str = "", + auth_provider_session_id: Optional[str] = None, ) -> str: """Generate a login token suitable for m.login.token authentication @@ -643,6 +644,7 @@ class ModuleApi: return self._hs.get_macaroon_generator().generate_short_term_login_token( user_id, auth_provider_id, + auth_provider_session_id, duration_in_ms, ) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 0db419ea57..daacc34cea 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -46,6 +46,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest, is_appservice_ghost, should_issue_refresh_token, + auth_provider_id, + auth_provider_session_id, ): """ Args: @@ -63,6 +65,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "is_guest": is_guest, "is_appservice_ghost": is_appservice_ghost, "should_issue_refresh_token": should_issue_refresh_token, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, } async def _handle_request(self, request, user_id): @@ -73,6 +77,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] should_issue_refresh_token = content["should_issue_refresh_token"] + auth_provider_id = content["auth_provider_id"] + auth_provider_session_id = content["auth_provider_session_id"] res = await self.registration_handler.register_device_inner( user_id, @@ -81,6 +87,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): is_guest, is_appservice_ghost=is_appservice_ghost, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, ) return 200, res diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index a66ee4fb3d..1b23fa18cf 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -303,6 +303,7 @@ class LoginRestServlet(RestServlet): ratelimit: bool = True, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, + auth_provider_session_id: Optional[str] = None, ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -318,10 +319,10 @@ class LoginRestServlet(RestServlet): create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. ratelimit: Whether to ratelimit the login request. - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). + auth_provider_id: The SSO IdP the user used, if any. should_issue_refresh_token: True if this login should issue a refresh token alongside the access token. + auth_provider_session_id: The session ID got during login from the SSO IdP. Returns: result: Dictionary of account information after successful login. @@ -354,6 +355,7 @@ class LoginRestServlet(RestServlet): initial_display_name, auth_provider_id=auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=auth_provider_session_id, ) result = LoginResponse( @@ -399,6 +401,7 @@ class LoginRestServlet(RestServlet): self.auth_handler._sso_login_callback, auth_provider_id=res.auth_provider_id, should_issue_refresh_token=should_issue_refresh_token, + auth_provider_session_id=res.auth_provider_session_id, ) async def _do_jwt_login( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9ccc66e589..d5a4a661cd 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore): return {d["device_id"]: d for d in devices} + async def get_devices_by_auth_provider_session_id( + self, auth_provider_id: str, auth_provider_session_id: str + ) -> List[Dict[str, Any]]: + """Retrieve the list of devices associated with a SSO IdP session ID. + + Args: + auth_provider_id: The SSO IdP ID as defined in the server config + auth_provider_session_id: The session ID within the IdP + Returns: + A list of dicts containing the device_id and the user_id of each device + """ + return await self.db_pool.simple_select_list( + table="device_auth_providers", + keyvalues={ + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + retcols=("user_id", "device_id"), + desc="get_devices_by_auth_provider_session_id", + ) + @trace async def get_device_updates_by_remote( self, destination: str, from_stream_id: int, limit: int @@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def store_device( - self, user_id: str, device_id: str, initial_device_display_name: Optional[str] + self, + user_id: str, + device_id: str, + initial_device_display_name: Optional[str], + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, ) -> bool: """Ensure the given device is known; add it to the store if not @@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id: id of device initial_device_display_name: initial displayname of the device. Ignored if device exists. + auth_provider_id: The SSO IdP the user used, if any. + auth_provider_session_id: The session ID (sid) got from a OIDC login. Returns: Whether the device was inserted or an existing device existed with that ID. @@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) + if auth_provider_id and auth_provider_session_id: + await self.db_pool.simple_insert( + "device_auth_providers", + values={ + "user_id": user_id, + "device_id": device_id, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + desc="store_device_auth_provider", + ) + self.device_id_exists_cache.set(key, True) return inserted except StoreError: @@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): keyvalues={"user_id": user_id}, ) + self.db_pool.simple_delete_many_txn( + txn, + table="device_auth_providers", + column="device_id", + values=device_ids, + keyvalues={"user_id": user_id}, + ) + await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) diff --git a/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql new file mode 100644 index 0000000000..a65bfb520d --- /dev/null +++ b/synapse/storage/schema/main/delta/65/11_devices_auth_provider_session.sql @@ -0,0 +1,27 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Track the auth provider used by each login as well as the session ID +CREATE TABLE device_auth_providers ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + auth_provider_id TEXT NOT NULL, + auth_provider_session_id TEXT NOT NULL +); + +CREATE INDEX device_auth_providers_devices + ON device_auth_providers (user_id, device_id); +CREATE INDEX device_auth_providers_sessions + ON device_auth_providers (auth_provider_id, auth_provider_session_id); diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 72e176da75..03b8b8615c 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -71,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_short_term_login_token_gives_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) self.assertEqual(self.user1, res.user_id) @@ -94,7 +94,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_short_term_login_token_cannot_replace_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) macaroon = pymacaroons.Macaroon.deserialize(token) @@ -213,6 +213,6 @@ class AuthTestCase(unittest.HomeserverTestCase): def _get_macaroon(self): token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", 5000 + self.user1, "", duration_in_ms=5000 ) return pymacaroons.Macaroon.deserialize(token) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index b625995d12..8705ff8943 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -66,7 +66,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=True + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) def test_map_cas_user_to_existing_user(self): @@ -89,7 +95,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=False + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=False, + auth_provider_session_id=None, ) # Subsequent calls should map to the same mxid. @@ -98,7 +110,13 @@ class CasHandlerTestCase(HomeserverTestCase): self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=False + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=False, + auth_provider_session_id=None, ) def test_map_cas_user_to_invalid_localpart(self): @@ -116,7 +134,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True + "@f=c3=b6=c3=b6:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -160,7 +184,13 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "cas", request, "redirect_uri", None, new_user=True + "@test_user:test", + "cas", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index a25c89bd5b..cfe3de5266 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -252,13 +252,6 @@ class OidcHandlerTestCase(HomeserverTestCase): with patch.object(self.provider, "load_metadata", patched_load_metadata): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) - # Return empty key set if JWKS are not used - self.provider._scopes = [] # not asking the openid scope - self.http_client.get_json.reset_mock() - jwks = self.get_success(self.provider.load_jwks(force=True)) - self.http_client.get_json.assert_not_called() - self.assertEqual(jwks, {"keys": []}) - @override_config({"oidc_config": DEFAULT_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" @@ -455,7 +448,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, "oidc", request, client_redirect_url, None, new_user=True + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=True, + auth_provider_session_id=None, ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -482,17 +481,58 @@ class OidcHandlerTestCase(HomeserverTestCase): self.provider._fetch_userinfo.reset_mock() # With userinfo fetching - self.provider._scopes = [] # do not ask the "openid" scope + self.provider._user_profile_method = "userinfo_endpoint" + token = { + "type": "bearer", + "access_token": "access_token", + } + self.provider._exchange_code = simple_async_mock(return_value=token) self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, "oidc", request, client_redirect_url, None, new_user=False + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=False, + auth_provider_session_id=None, ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_not_called() self.provider._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() + # With an ID token, userinfo fetching and sid in the ID token + self.provider._user_profile_method = "userinfo_endpoint" + token = { + "type": "bearer", + "access_token": "access_token", + "id_token": "id_token", + } + id_token = { + "sid": "abcdefgh", + } + self.provider._parse_id_token = simple_async_mock(return_value=id_token) + self.provider._exchange_code = simple_async_mock(return_value=token) + auth_handler.complete_sso_login.reset_mock() + self.provider._fetch_userinfo.reset_mock() + self.get_success(self.handler.handle_oidc_callback(request)) + + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, + "oidc", + request, + client_redirect_url, + None, + new_user=False, + auth_provider_session_id=id_token["sid"], + ) + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.provider._fetch_userinfo.assert_called_once_with(token) + self.render_error.assert_not_called() + # Handle userinfo fetching error self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) @@ -776,6 +816,7 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url, {"phone": "1234567"}, new_user=True, + auth_provider_session_id=None, ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -790,7 +831,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "oidc", ANY, ANY, None, new_user=True + "@test_user:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -801,7 +848,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True + "@test_user_2:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -838,14 +891,26 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -860,7 +925,13 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), "oidc", ANY, ANY, None, new_user=False + user.to_string(), + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -896,7 +967,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False + "@TEST_USER_2:test", + "oidc", + ANY, + ANY, + None, + new_user=False, + auth_provider_session_id=None, ) @override_config({"oidc_config": DEFAULT_CONFIG}) @@ -934,7 +1011,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", "oidc", ANY, ANY, None, new_user=True + "@test_user1:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -1018,7 +1101,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", "oidc", ANY, ANY, None, new_user=True + "@tester:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -1043,7 +1132,13 @@ class OidcHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@tester:test", "oidc", ANY, ANY, None, new_user=True + "@tester:test", + "oidc", + ANY, + ANY, + None, + new_user=True, + auth_provider_session_id=None, ) @override_config( @@ -1156,7 +1251,7 @@ async def _make_callback_with_userinfo( handler = hs.get_oidc_handler() provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={}) + provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 8cfc184fef..50551aa6e3 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -130,7 +130,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "redirect_uri", None, new_user=True + "@test_user:test", + "saml", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -156,7 +162,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "", None, new_user=False + "@test_user:test", + "saml", + request, + "", + None, + new_user=False, + auth_provider_session_id=None, ) # Subsequent calls should map to the same mxid. @@ -165,7 +177,13 @@ class SamlHandlerTestCase(HomeserverTestCase): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "", None, new_user=False + "@test_user:test", + "saml", + request, + "", + None, + new_user=False, + auth_provider_session_id=None, ) def test_map_saml_response_to_invalid_localpart(self): @@ -213,7 +231,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", "saml", request, "", None, new_user=True + "@test_user1:test", + "saml", + request, + "", + None, + new_user=True, + auth_provider_session_id=None, ) auth_handler.complete_sso_login.reset_mock() @@ -309,7 +333,13 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", "saml", request, "redirect_uri", None, new_user=True + "@test_user:test", + "saml", + request, + "redirect_uri", + None, + new_user=True, + auth_provider_session_id=None, ) -- cgit 1.5.1 From 2f053f3f82ca174cc1c858c75afffae51af8ce0d Mon Sep 17 00:00:00 2001 From: reivilibre Date: Mon, 6 Dec 2021 19:11:43 +0000 Subject: Stabilise support for MSC2918 refresh tokens as they have now been merged into the Matrix specification. (#11435) --- changelog.d/11435.feature | 1 + docs/sample_config.yaml | 38 ++++++++++++++++++++++++++++++++++++++ synapse/config/registration.py | 38 ++++++++++++++++++++++++++++++++++++++ synapse/rest/client/login.py | 29 +++++++++++++---------------- synapse/rest/client/register.py | 23 ++++++++++------------- tests/rest/client/test_auth.py | 30 +++++++++++++++--------------- 6 files changed, 115 insertions(+), 44 deletions(-) create mode 100644 changelog.d/11435.feature (limited to 'tests') diff --git a/changelog.d/11435.feature b/changelog.d/11435.feature new file mode 100644 index 0000000000..9e127fae3c --- /dev/null +++ b/changelog.d/11435.feature @@ -0,0 +1 @@ +Stabilise support for [MSC2918](https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens) refresh tokens as they have now been merged into the Matrix specification. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index ae476d19ac..6696ed5d1e 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1209,6 +1209,44 @@ oembed: # #session_lifetime: 24h +# Time that an access token remains valid for, if the session is +# using refresh tokens. +# For more information about refresh tokens, please see the manual. +# Note that this only applies to clients which advertise support for +# refresh tokens. +# +# Note also that this is calculated at login time and refresh time: +# changes are not applied to existing sessions until they are refreshed. +# +# By default, this is 5 minutes. +# +#refreshable_access_token_lifetime: 5m + +# Time that a refresh token remains valid for (provided that it is not +# exchanged for another one first). +# This option can be used to automatically log-out inactive sessions. +# Please see the manual for more information. +# +# Note also that this is calculated at login time and refresh time: +# changes are not applied to existing sessions until they are refreshed. +# +# By default, this is infinite. +# +#refresh_token_lifetime: 24h + +# Time that an access token remains valid for, if the session is NOT +# using refresh tokens. +# Please note that not all clients support refresh tokens, so setting +# this to a short value may be inconvenient for some users who will +# then be logged out frequently. +# +# Note also that this is calculated at login time: changes are not applied +# retrospectively to existing sessions for users that have already logged in. +# +# By default, this is infinite. +# +#nonrefreshable_access_token_lifetime: 24h + # The user must provide all of the below types of 3PID when registering. # #registrations_require_3pid: diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 68a4985398..7a059c6dec 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -220,6 +220,44 @@ class RegistrationConfig(Config): # #session_lifetime: 24h + # Time that an access token remains valid for, if the session is + # using refresh tokens. + # For more information about refresh tokens, please see the manual. + # Note that this only applies to clients which advertise support for + # refresh tokens. + # + # Note also that this is calculated at login time and refresh time: + # changes are not applied to existing sessions until they are refreshed. + # + # By default, this is 5 minutes. + # + #refreshable_access_token_lifetime: 5m + + # Time that a refresh token remains valid for (provided that it is not + # exchanged for another one first). + # This option can be used to automatically log-out inactive sessions. + # Please see the manual for more information. + # + # Note also that this is calculated at login time and refresh time: + # changes are not applied to existing sessions until they are refreshed. + # + # By default, this is infinite. + # + #refresh_token_lifetime: 24h + + # Time that an access token remains valid for, if the session is NOT + # using refresh tokens. + # Please note that not all clients support refresh tokens, so setting + # this to a short value may be inconvenient for some users who will + # then be logged out frequently. + # + # Note also that this is calculated at login time: changes are not applied + # retrospectively to existing sessions for users that have already logged in. + # + # By default, this is infinite. + # + #nonrefreshable_access_token_lifetime: 24h + # The user must provide all of the below types of 3PID when registering. # #registrations_require_3pid: diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 1b23fa18cf..f9994658c4 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -72,7 +72,7 @@ class LoginRestServlet(RestServlet): JWT_TYPE_DEPRECATED = "m.login.jwt" APPSERVICE_TYPE = "m.login.application_service" APPSERVICE_TYPE_UNSTABLE = "uk.half-shot.msc2778.login.application_service" - REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" + REFRESH_TOKEN_PARAM = "refresh_token" def __init__(self, hs: "HomeServer"): super().__init__() @@ -90,7 +90,7 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = ( + self._refresh_tokens_enabled = ( hs.config.registration.refreshable_access_token_lifetime is not None ) @@ -163,17 +163,16 @@ class LoginRestServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]: login_submission = parse_json_object_from_request(request) - if self._msc2918_enabled: - # Check if this login should also issue a refresh token, as per MSC2918 - should_issue_refresh_token = login_submission.get( - "org.matrix.msc2918.refresh_token", False - ) - if not isinstance(should_issue_refresh_token, bool): - raise SynapseError( - 400, "`org.matrix.msc2918.refresh_token` should be true or false." - ) - else: - should_issue_refresh_token = False + # Check to see if the client requested a refresh token. + client_requested_refresh_token = login_submission.get( + LoginRestServlet.REFRESH_TOKEN_PARAM, False + ) + if not isinstance(client_requested_refresh_token, bool): + raise SynapseError(400, "`refresh_token` should be true or false.") + + should_issue_refresh_token = ( + self._refresh_tokens_enabled and client_requested_refresh_token + ) try: if login_submission["type"] in ( @@ -463,9 +462,7 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: class RefreshTokenServlet(RestServlet): - PATTERNS = client_patterns( - "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True - ) + PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),) def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 11fd6cd24d..8b56c76aed 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -419,7 +419,7 @@ class RegisterRestServlet(RestServlet): self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_enabled = self.hs.config.registration.enable_registration - self._msc2918_enabled = ( + self._refresh_tokens_enabled = ( hs.config.registration.refreshable_access_token_lifetime is not None ) @@ -445,18 +445,15 @@ class RegisterRestServlet(RestServlet): f"Do not understand membership kind: {kind}", ) - if self._msc2918_enabled: - # Check if this registration should also issue a refresh token, as - # per MSC2918 - should_issue_refresh_token = body.get( - "org.matrix.msc2918.refresh_token", False - ) - if not isinstance(should_issue_refresh_token, bool): - raise SynapseError( - 400, "`org.matrix.msc2918.refresh_token` should be true or false." - ) - else: - should_issue_refresh_token = False + # Check if the clients wishes for this registration to issue a refresh + # token. + client_requested_refresh_tokens = body.get("refresh_token", False) + if not isinstance(client_requested_refresh_tokens, bool): + raise SynapseError(400, "`refresh_token` should be true or false.") + + should_issue_refresh_token = ( + self._refresh_tokens_enabled and client_requested_refresh_tokens + ) # Pull out the provided username and do basic sanity checks early since # the auth layer will store these in sessions. diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 7239e1a1b5..aa8ad6d2e1 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -520,7 +520,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): """ return self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": refresh_token}, ) @@ -557,7 +557,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): login_with_refresh = self.make_request( "POST", "/_matrix/client/r0/login", - {"org.matrix.msc2918.refresh_token": True, **body}, + {"refresh_token": True, **body}, ) self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) self.assertIn("refresh_token", login_with_refresh.json_body) @@ -588,7 +588,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "username": "test3", "password": self.user_pass, "auth": {"type": LoginType.DUMMY}, - "org.matrix.msc2918.refresh_token": True, + "refresh_token": True, }, ) self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) @@ -603,7 +603,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "type": "m.login.password", "user": "test", "password": self.user_pass, - "org.matrix.msc2918.refresh_token": True, + "refresh_token": True, } login_response = self.make_request( "POST", @@ -614,7 +614,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, 200, refresh_response.result) @@ -641,7 +641,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "type": "m.login.password", "user": "test", "password": self.user_pass, - "org.matrix.msc2918.refresh_token": True, + "refresh_token": True, } login_response = self.make_request( "POST", @@ -655,7 +655,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual(refresh_response.code, 200, refresh_response.result) @@ -761,7 +761,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "type": "m.login.password", "user": "test", "password": self.user_pass, - "org.matrix.msc2918.refresh_token": True, + "refresh_token": True, } login_response = self.make_request( "POST", @@ -811,7 +811,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "type": "m.login.password", "user": "test", "password": self.user_pass, - "org.matrix.msc2918.refresh_token": True, + "refresh_token": True, } login_response = self.make_request( "POST", @@ -868,7 +868,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): "type": "m.login.password", "user": "test", "password": self.user_pass, - "org.matrix.msc2918.refresh_token": True, + "refresh_token": True, } login_response = self.make_request( "POST", @@ -880,7 +880,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This first refresh should work properly first_refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -890,7 +890,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one as well, since the token in the first one was never used second_refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -900,7 +900,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # This one should not, since the token from the first refresh is not valid anymore third_refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": first_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -928,7 +928,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail fourth_refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": login_response.json_body["refresh_token"]}, ) self.assertEqual( @@ -938,7 +938,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): # But refreshing from the last valid refresh token still works fifth_refresh_response = self.make_request( "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + "/_matrix/client/v1/refresh", {"refresh_token": second_refresh_response.json_body["refresh_token"]}, ) self.assertEqual( -- cgit 1.5.1 From 2d42e586a8c54be1a83643148358b1651c1ca666 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 7 Dec 2021 10:49:39 +0000 Subject: Fix the test breakage introduced by #11435 as a result of concurrent PRs (#11522) --- changelog.d/11522.feature | 1 + tests/rest/client/test_auth.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/11522.feature (limited to 'tests') diff --git a/changelog.d/11522.feature b/changelog.d/11522.feature new file mode 100644 index 0000000000..9e127fae3c --- /dev/null +++ b/changelog.d/11522.feature @@ -0,0 +1 @@ +Stabilise support for [MSC2918](https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens) refresh tokens as they have now been merged into the Matrix specification. \ No newline at end of file diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index aa8ad6d2e1..72bbc87b4a 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -703,7 +703,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase): login_response1 = self.make_request( "POST", "/_matrix/client/r0/login", - {"org.matrix.msc2918.refresh_token": True, **body}, + {"refresh_token": True, **body}, ) self.assertEqual(login_response1.code, 200, login_response1.result) self.assertApproximates( -- cgit 1.5.1 From b1ecd19c5d19815b69e425d80f442bf2877cab76 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 7 Dec 2021 11:37:54 +0000 Subject: Fix 'delete room' admin api to work on incomplete rooms (#11523) If, for some reason, we don't have the create event, we should still be able to purge a room. --- changelog.d/11523.feature | 1 + synapse/handlers/pagination.py | 3 --- synapse/handlers/room.py | 21 +++++++-------------- synapse/rest/admin/rooms.py | 3 --- tests/rest/admin/test_room.py | 42 +++++++++++++++++++++++++----------------- 5 files changed, 33 insertions(+), 37 deletions(-) create mode 100644 changelog.d/11523.feature (limited to 'tests') diff --git a/changelog.d/11523.feature b/changelog.d/11523.feature new file mode 100644 index 0000000000..ecac7f9db9 --- /dev/null +++ b/changelog.d/11523.feature @@ -0,0 +1 @@ +Extend the "delete room" admin api to work correctly on rooms which have previously been partially deleted. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index cd64142735..4f42438053 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -406,9 +406,6 @@ class PaginationHandler: force: set true to skip checking for joined users. """ with await self.pagination_lock.write(room_id): - # check we know about the room - await self.store.get_room_version_id(room_id) - # first check that we have no users in this room if not force: joined = await self.store.is_host_joined(room_id, self._server_name) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2bcdf32dcc..ead2198e14 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1535,20 +1535,13 @@ class RoomShutdownHandler: await self.store.block_room(room_id, requester_user_id) if not await self.store.get_room(room_id): - if block: - # We allow you to block an unknown room. - return { - "kicked_users": [], - "failed_to_kick_users": [], - "local_aliases": [], - "new_room_id": None, - } - else: - # But if you don't want to preventatively block another room, - # this function can't do anything useful. - raise NotFoundError( - "Cannot shut down room: unknown room id %s" % (room_id,) - ) + # if we don't know about the room, there is nothing left to do. + return { + "kicked_users": [], + "failed_to_kick_users": [], + "local_aliases": [], + "new_room_id": None, + } if new_room_user_id is not None: if not self.hs.is_mine_id(new_room_user_id): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 669ab44a45..829e86675a 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -106,9 +106,6 @@ class RoomRestV2Servlet(RestServlet): HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) ) - if not await self._store.get_room(room_id): - raise NotFoundError("Unknown room id %s" % (room_id,)) - delete_id = self._pagination_handler.start_shutdown_and_purge_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index d3858e460d..22f9aa6234 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -83,7 +83,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def test_room_does_not_exist(self): """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + Check that unknown rooms/server return 200 """ url = "/_synapse/admin/v1/rooms/%s" % "!unknown:test" @@ -94,8 +94,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) def test_room_is_not_valid(self): """ @@ -508,27 +507,36 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - @parameterized.expand( - [ - ("DELETE", "/_synapse/admin/v2/rooms/%s"), - ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), - ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), - ] - ) - def test_room_does_not_exist(self, method: str, url: str): - """ - Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. + def test_room_does_not_exist(self): """ + Check that unknown rooms/server return 200 + This is important, as it allows incomplete vestiges of rooms to be cleared up + even if the create event/etc is missing. + """ + room_id = "!unknown:test" channel = self.make_request( - method, - url % "!unknown:test", + "DELETE", + f"/_synapse/admin/v2/rooms/{room_id}", content={}, access_token=self.admin_user_tok, ) - self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + # get status + channel = self.make_request( + "GET", + f"/_synapse/admin/v2/rooms/{room_id}/delete_status", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual(delete_id, channel.json_body["results"][0]["delete_id"]) @parameterized.expand( [ -- cgit 1.5.1 From 088d748f2cb51f03f3bcacc0fb3af1e0f9607737 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 7 Dec 2021 13:51:11 +0000 Subject: Revert "Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common` (#11505) (#11527) This reverts commit a77c36989785c0d5565ab9a1169f4f88e512ce8a. --- changelog.d/11527.misc | 1 + synapse/config/room_directory.py | 3 +- synapse/config/tls.py | 3 +- synapse/federation/federation_server.py | 3 +- synapse/push/push_rule_evaluator.py | 7 ++-- synapse/python_dependencies.py | 1 - synapse/util/__init__.py | 59 ++++++++++++++++++++++++++++++++- tests/util/test_glob_to_regex.py | 59 +++++++++++++++++++++++++++++++++ 8 files changed, 124 insertions(+), 12 deletions(-) create mode 100644 changelog.d/11527.misc create mode 100644 tests/util/test_glob_to_regex.py (limited to 'tests') diff --git a/changelog.d/11527.misc b/changelog.d/11527.misc new file mode 100644 index 0000000000..081eae317c --- /dev/null +++ b/changelog.d/11527.misc @@ -0,0 +1 @@ +Temporarily revert usage of `matrix-python-common`. diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 3c5e0f7ce7..57316c59b6 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -15,9 +15,8 @@ from typing import List -from matrix_common.regex import glob_to_regex - from synapse.types import JsonDict +from synapse.util import glob_to_regex from ._base import Config, ConfigError diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 3e235b57a7..4ca111618f 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -16,12 +16,11 @@ import logging import os from typing import List, Optional, Pattern -from matrix_common.regex import glob_to_regex - from OpenSSL import SSL, crypto from twisted.internet._sslverify import Certificate, trustRootFromCertificates from synapse.config._base import Config, ConfigError +from synapse.util import glob_to_regex logger = logging.getLogger(__name__) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 4697a62c18..8e37e76206 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -28,7 +28,6 @@ from typing import ( Union, ) -from matrix_common.regex import glob_to_regex from prometheus_client import Counter, Gauge, Histogram from twisted.internet import defer @@ -67,7 +66,7 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.lock import Lock from synapse.types import JsonDict, get_domain_from_id -from synapse.util import json_decoder, unwrapFirstError +from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_server_name diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 659a53805d..7f68092ec5 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -17,10 +17,9 @@ import logging import re from typing import Any, Dict, List, Optional, Pattern, Tuple, Union -from matrix_common.regex import glob_to_regex, to_word_pattern - from synapse.events import EventBase from synapse.types import JsonDict, UserID +from synapse.util import glob_to_regex, re_word_boundary from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -185,7 +184,7 @@ class PushRuleEvaluatorForEvent: r = regex_cache.get((display_name, False, True), None) if not r: r1 = re.escape(display_name) - r1 = to_word_pattern(r1) + r1 = re_word_boundary(r1) r = re.compile(r1, flags=re.IGNORECASE) regex_cache[(display_name, False, True)] = r @@ -214,7 +213,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: try: r = regex_cache.get((glob, True, word_boundary), None) if not r: - r = glob_to_regex(glob, word_boundary=word_boundary) + r = glob_to_regex(glob, word_boundary) regex_cache[(glob, True, word_boundary)] = r return bool(r.search(value)) except re.error: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 386debd7db..7d26954244 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -87,7 +87,6 @@ REQUIREMENTS = [ # with the latest security patches. "cryptography>=3.4.7", "ijson>=3.1", - "matrix-common==1.0.0", ] CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index f157132210..95f23e27b6 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -14,8 +14,9 @@ import json import logging +import re import typing -from typing import Any, Callable, Dict, Generator, Optional +from typing import Any, Callable, Dict, Generator, Optional, Pattern import attr from frozendict import frozendict @@ -34,6 +35,9 @@ if typing.TYPE_CHECKING: logger = logging.getLogger(__name__) +_WILDCARD_RUN = re.compile(r"([\?\*]+)") + + def _reject_invalid_json(val: Any) -> None: """Do not allow Infinity, -Infinity, or NaN values in JSON.""" raise ValueError("Invalid JSON value: '%s'" % val) @@ -181,3 +185,56 @@ def log_failure( if not consumeErrors: return failure return None + + +def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern: + """Converts a glob to a compiled regex object. + + Args: + glob: pattern to match + word_boundary: If True, the pattern will be allowed to match at word boundaries + anywhere in the string. Otherwise, the pattern is anchored at the start and + end of the string. + + Returns: + compiled regex pattern + """ + + # Patterns with wildcards must be simplified to avoid performance cliffs + # - The glob `?**?**?` is equivalent to the glob `???*` + # - The glob `???*` is equivalent to the regex `.{3,}` + chunks = [] + for chunk in _WILDCARD_RUN.split(glob): + # No wildcards? re.escape() + if not _WILDCARD_RUN.match(chunk): + chunks.append(re.escape(chunk)) + continue + + # Wildcards? Simplify. + qmarks = chunk.count("?") + if "*" in chunk: + chunks.append(".{%d,}" % qmarks) + else: + chunks.append(".{%d}" % qmarks) + + res = "".join(chunks) + + if word_boundary: + res = re_word_boundary(res) + else: + # \A anchors at start of string, \Z at end of string + res = r"\A" + res + r"\Z" + + return re.compile(res, re.IGNORECASE) + + +def re_word_boundary(r: str) -> str: + """ + Adds word boundary characters to the start and end of an + expression to require that the match occur as a whole word, + but do so respecting the fact that strings starting or ending + with non-word characters will change word boundaries. + """ + # we can't use \b as it chokes on unicode. however \W seems to be okay + # as shorthand for [^0-9A-Za-z_]. + return r"(^|\W)%s(\W|$)" % (r,) diff --git a/tests/util/test_glob_to_regex.py b/tests/util/test_glob_to_regex.py new file mode 100644 index 0000000000..220accb92b --- /dev/null +++ b/tests/util/test_glob_to_regex.py @@ -0,0 +1,59 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.util import glob_to_regex + +from tests.unittest import TestCase + + +class GlobToRegexTestCase(TestCase): + def test_literal_match(self): + """patterns without wildcards should match""" + pat = glob_to_regex("foobaz") + self.assertTrue( + pat.match("FoobaZ"), "patterns should match and be case-insensitive" + ) + self.assertFalse( + pat.match("x foobaz"), "pattern should not match at word boundaries" + ) + + def test_wildcard_match(self): + pat = glob_to_regex("f?o*baz") + + self.assertTrue( + pat.match("FoobarbaZ"), + "* should match string and pattern should be case-insensitive", + ) + self.assertTrue(pat.match("foobaz"), "* should match 0 characters") + self.assertFalse(pat.match("fooxaz"), "the character after * must match") + self.assertFalse(pat.match("fobbaz"), "? should not match 0 characters") + self.assertFalse(pat.match("fiiobaz"), "? should not match 2 characters") + + def test_multi_wildcard(self): + """patterns with multiple wildcards in a row should match""" + pat = glob_to_regex("**baz") + self.assertTrue(pat.match("agsgsbaz"), "** should match any string") + self.assertTrue(pat.match("baz"), "** should match the empty string") + self.assertEqual(pat.pattern, r"\A.{0,}baz\Z") + + pat = glob_to_regex("*?baz") + self.assertTrue(pat.match("agsgsbaz"), "*? should match any string") + self.assertTrue(pat.match("abaz"), "*? should match a single char") + self.assertFalse(pat.match("baz"), "*? should not match the empty string") + self.assertEqual(pat.pattern, r"\A.{1,}baz\Z") + + pat = glob_to_regex("a?*?*?baz") + self.assertTrue(pat.match("a g baz"), "?*?*? should match 3 chars") + self.assertFalse(pat.match("a..baz"), "?*?*? should not match 2 chars") + self.assertTrue(pat.match("a.gg.baz"), "?*?*? should match 4 chars") + self.assertEqual(pat.pattern, r"\Aa.{3,}baz\Z") -- cgit 1.5.1 From d6fb96e056f79de220d8d59429d89a61498e9af3 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 7 Dec 2021 16:51:53 +0000 Subject: Fix case in `wait_for_background_updates` where `self.store` does not exist (#11331) Pull the DataStore from the HomeServer instance, which always exists. --- changelog.d/11331.misc | 1 + tests/unittest.py | 11 ++++------- 2 files changed, 5 insertions(+), 7 deletions(-) create mode 100644 changelog.d/11331.misc (limited to 'tests') diff --git a/changelog.d/11331.misc b/changelog.d/11331.misc new file mode 100644 index 0000000000..1ab3a6a975 --- /dev/null +++ b/changelog.d/11331.misc @@ -0,0 +1 @@ +A test helper (`wait_for_background_updates`) no longer depends on classes defining a `store` property. diff --git a/tests/unittest.py b/tests/unittest.py index eea0903f05..1431848367 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -331,16 +331,13 @@ class HomeserverTestCase(TestCase): time.sleep(0.01) def wait_for_background_updates(self) -> None: - """Block until all background database updates have completed. - - Note that callers must ensure there's a store property created on the - testcase. - """ + """Block until all background database updates have completed.""" + store = self.hs.get_datastore() while not self.get_success( - self.store.db_pool.updates.has_completed_background_updates() + store.db_pool.updates.has_completed_background_updates() ): self.get_success( - self.store.db_pool.updates.do_next_background_update(False), by=0.1 + store.db_pool.updates.do_next_background_update(False), by=0.1 ) def make_homeserver(self, reactor, clock): -- cgit 1.5.1 From 8541809cb952ebf0da2a95dd93eccd5644dab49d Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 8 Dec 2021 05:01:38 -0500 Subject: Send and handle cross-signing messages using the stable prefix. (#10520) --- changelog.d/10520.misc | 1 + synapse/handlers/e2e_keys.py | 8 ++++++-- synapse/storage/databases/main/devices.py | 4 +++- tests/federation/test_federation_sender.py | 5 +++-- 4 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10520.misc (limited to 'tests') diff --git a/changelog.d/10520.misc b/changelog.d/10520.misc new file mode 100644 index 0000000000..a911e165da --- /dev/null +++ b/changelog.d/10520.misc @@ -0,0 +1 @@ +Send and handle cross-signing messages using the stable prefix. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 60c11e3d21..b2554bda04 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -65,8 +65,12 @@ class E2eKeysHandler: else: # Only register this edu handler on master as it requires writing # device updates to the db - # - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + federation_registry.register_edu_handler( + "m.signing_key_update", + self._edu_updater.incoming_signing_key_update, + ) + # also handle the unstable version + # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( "org.matrix.signing_key_update", self._edu_updater.incoming_signing_key_update, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d5a4a661cd..838a2a6a3d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -274,7 +274,9 @@ class DeviceWorkerStore(SQLBaseStore): # add the updated cross-signing keys to the results list for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id - # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec + results.append(("m.signing_key_update", result)) + # also send the unstable version + # FIXME: remove this when enough servers have upgraded results.append(("org.matrix.signing_key_update", result)) return now_stream_id, results diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b457dad6d2..b2376e2db9 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -266,7 +266,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ) # expect signing key update edu - self.assertEqual(len(self.edus), 1) + self.assertEqual(len(self.edus), 2) + self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update") self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update") # sign the devices @@ -491,7 +492,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ) -> None: """Check that the txn has an EDU with a signing key update.""" edus = txn["edus"] - self.assertEqual(len(edus), 1) + self.assertEqual(len(edus), 2) def generate_and_upload_device_signing_key( self, user_id: str, device_id: str -- cgit 1.5.1