diff options
author | Ben Banfield-Zanin <benbz@matrix.org> | 2021-03-01 10:06:09 +0000 |
---|---|---|
committer | Ben Banfield-Zanin <benbz@matrix.org> | 2021-03-01 10:06:09 +0000 |
commit | b26bee9faf957643cd34c4146b250b0009be205d (patch) | |
tree | a7a7e29f30acb437d010bdf6116c0f2729f21a1b /tests/rest | |
parent | Merge remote-tracking branch 'origin/release-v1.26.0' into toml/keycloak_hints (diff) | |
parent | Fixup changelog (diff) | |
download | synapse-toml/keycloak_hints.tar.xz |
Merge remote-tracking branch 'origin/release-v1.28.0' into toml/keycloak_hints github/toml/keycloak_hints toml/keycloak_hints
Diffstat (limited to 'tests/rest')
26 files changed, 1905 insertions, 579 deletions
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 9d22c04073..057e27372e 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -130,8 +130,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): ) def _get_groups_user_is_in(self, access_token): - """Returns the list of groups the user is in (given their access token) - """ + """Returns the list of groups the user is in (given their access token)""" channel = self.make_request( "GET", "/joined_groups".encode("ascii"), access_token=access_token ) @@ -142,8 +141,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): class QuarantineMediaTestCase(unittest.HomeserverTestCase): - """Test /quarantine_media admin API. - """ + """Test /quarantine_media admin API.""" servlets = [ synapse.rest.admin.register_servlets, @@ -237,7 +235,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Attempt quarantine media APIs as non-admin url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345" channel = self.make_request( - "POST", url.encode("ascii"), access_token=non_admin_user_tok, + "POST", + url.encode("ascii"), + access_token=non_admin_user_tok, ) # Expect a forbidden error @@ -250,7 +250,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # And the roomID/userID endpoint url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine" channel = self.make_request( - "POST", url.encode("ascii"), access_token=non_admin_user_tok, + "POST", + url.encode("ascii"), + access_token=non_admin_user_tok, ) # Expect a forbidden error @@ -294,7 +296,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): urllib.parse.quote(server_name), urllib.parse.quote(media_id), ) - channel = self.make_request("POST", url, access_token=admin_user_tok,) + channel = self.make_request( + "POST", + url, + access_token=admin_user_tok, + ) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) @@ -346,7 +352,11 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote( room_id ) - channel = self.make_request("POST", url, access_token=admin_user_tok,) + channel = self.make_request( + "POST", + url, + access_token=admin_user_tok, + ) self.pump(1.0) self.assertEqual(200, int(channel.code), msg=channel.result["body"]) self.assertEqual( @@ -391,7 +401,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): non_admin_user ) channel = self.make_request( - "POST", url.encode("ascii"), access_token=admin_user_tok, + "POST", + url.encode("ascii"), + access_token=admin_user_tok, ) self.pump(1.0) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -437,7 +449,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): non_admin_user ) channel = self.make_request( - "POST", url.encode("ascii"), access_token=admin_user_tok, + "POST", + url.encode("ascii"), + access_token=admin_user_tok, ) self.pump(1.0) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 248c4442c3..2a1bcf1760 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -70,21 +70,27 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): If the user is not a server admin, an error is returned. """ channel = self.make_request( - "GET", self.url, access_token=self.other_user_token, + "GET", + self.url, + access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) channel = self.make_request( - "PUT", self.url, access_token=self.other_user_token, + "PUT", + self.url, + access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) channel = self.make_request( - "DELETE", self.url, access_token=self.other_user_token, + "DELETE", + self.url, + access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -99,17 +105,29 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): % self.other_user_device_id ) - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -123,17 +141,29 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): % self.other_user_device_id ) - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -146,16 +176,28 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.other_user ) - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - channel = self.make_request("PUT", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) # Delete unknown device returns status 200 self.assertEqual(200, channel.code, msg=channel.json_body) @@ -190,7 +232,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -207,12 +253,20 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) ) - channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "PUT", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) # Ensure the display name was not updated. - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new display", channel.json_body["display_name"]) @@ -233,7 +287,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) # Check new display_name - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("new displayname", channel.json_body["display_name"]) @@ -242,7 +300,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): """ Tests that a normal lookup for a device is successfully """ - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) @@ -264,7 +326,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): # Delete device channel = self.make_request( - "DELETE", self.url, access_token=self.admin_user_tok, + "DELETE", + self.url, + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -306,7 +370,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - channel = self.make_request("GET", self.url, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url, + access_token=other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -316,7 +384,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/devices" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -327,7 +399,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -339,7 +415,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): """ # Get devices - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -355,7 +435,11 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.login("user", "pass") # Get devices - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_devices, channel.json_body["total"]) @@ -404,7 +488,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - channel = self.make_request("POST", self.url, access_token=other_user_token,) + channel = self.make_request( + "POST", + self.url, + access_token=other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -414,7 +502,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" - channel = self.make_request("POST", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "POST", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -425,7 +517,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" - channel = self.make_request("POST", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "POST", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index d0090faa4f..e30ffe4fa0 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -51,19 +51,23 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # Two rooms and two users. Every user sends and reports every room event for i in range(5): self._create_event_and_report( - room_id=self.room_id1, user_tok=self.other_user_tok, + room_id=self.room_id1, + user_tok=self.other_user_tok, ) for i in range(5): self._create_event_and_report( - room_id=self.room_id2, user_tok=self.other_user_tok, + room_id=self.room_id2, + user_tok=self.other_user_tok, ) for i in range(5): self._create_event_and_report( - room_id=self.room_id1, user_tok=self.admin_user_tok, + room_id=self.room_id1, + user_tok=self.admin_user_tok, ) for i in range(5): self._create_event_and_report( - room_id=self.room_id2, user_tok=self.admin_user_tok, + room_id=self.room_id2, + user_tok=self.admin_user_tok, ) self.url = "/_synapse/admin/v1/event_reports" @@ -82,7 +86,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase): If the user is not a server admin, an error 403 is returned. """ - channel = self.make_request("GET", self.url, access_token=self.other_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.other_user_tok, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -92,7 +100,11 @@ class EventReportsTestCase(unittest.HomeserverTestCase): Testing list of reported events """ - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -106,7 +118,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( - "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -121,7 +135,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( - "GET", self.url + "?from=5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -136,7 +152,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( - "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -213,7 +231,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # fetch the most recent first, largest timestamp channel = self.make_request( - "GET", self.url + "?dir=b", access_token=self.admin_user_tok, + "GET", + self.url + "?dir=b", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -229,7 +249,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # fetch the oldest first, smallest timestamp channel = self.make_request( - "GET", self.url + "?dir=f", access_token=self.admin_user_tok, + "GET", + self.url + "?dir=f", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -249,7 +271,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( - "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -262,7 +286,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( - "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -274,7 +300,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( - "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -288,7 +316,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of results is the number of entries channel = self.make_request( - "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=20", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -299,7 +329,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of max results is larger than the number of entries channel = self.make_request( - "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=21", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -310,7 +342,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # `next_token` does appear # Number of max results is smaller than the number of entries channel = self.make_request( - "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -322,7 +356,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase): # Set `from` to value of `next_token` for request remaining entries # `next_token` does not appear channel = self.make_request( - "GET", self.url + "?from=19", access_token=self.admin_user_tok, + "GET", + self.url + "?from=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -331,8 +367,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) def _create_event_and_report(self, room_id, user_tok): - """Create and report events - """ + """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -345,8 +380,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def _check_fields(self, content): - """Checks that all attributes are present in an event report - """ + """Checks that all attributes are present in an event report""" for c in content: self.assertIn("id", c) self.assertIn("received_ts", c) @@ -381,7 +415,8 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok) self._create_event_and_report( - room_id=self.room_id1, user_tok=self.other_user_tok, + room_id=self.room_id1, + user_tok=self.other_user_tok, ) # first created event report gets `id`=2 @@ -401,7 +436,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): If the user is not a server admin, an error 403 is returned. """ - channel = self.make_request("GET", self.url, access_token=self.other_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.other_user_tok, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -411,7 +450,11 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): Testing get a reported event """ - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self._check_fields(channel.json_body) @@ -479,8 +522,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): self.assertEqual("Event report not found", channel.json_body["error"]) def _create_event_and_report(self, room_id, user_tok): - """Create and report events - """ + """Create and report events""" resp = self.helper.send(room_id, tok=user_tok) event_id = resp["event_id"] @@ -493,8 +535,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def _check_fields(self, content): - """Checks that all attributes are present in a event report - """ + """Checks that all attributes are present in a event report""" self.assertIn("id", content) self.assertIn("received_ts", content) self.assertIn("room_id", content) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 51a7731693..31db472cd3 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -63,7 +63,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") - channel = self.make_request("DELETE", url, access_token=self.other_user_token,) + channel = self.make_request( + "DELETE", + url, + access_token=self.other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -74,7 +78,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -85,7 +93,11 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only delete local media", channel.json_body["error"]) @@ -139,12 +151,17 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id) # Delete media - channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - media_id, channel.json_body["deleted_media"][0], + media_id, + channel.json_body["deleted_media"][0], ) # Attempt to access media @@ -207,7 +224,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.other_user_token = self.login("user", "pass") channel = self.make_request( - "POST", self.url, access_token=self.other_user_token, + "POST", + self.url, + access_token=self.other_user_token, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -220,7 +239,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" channel = self.make_request( - "POST", url + "?before_ts=1234", access_token=self.admin_user_tok, + "POST", + url + "?before_ts=1234", + access_token=self.admin_user_tok, ) self.assertEqual(400, channel.code, msg=channel.json_body) @@ -230,7 +251,11 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ If the parameter `before_ts` is missing, an error is returned. """ - channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) @@ -243,7 +268,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): If parameters are invalid, an error is returned. """ channel = self.make_request( - "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok, + "POST", + self.url + "?before_ts=-1234", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -304,7 +331,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - media_id, channel.json_body["deleted_media"][0], + media_id, + channel.json_body["deleted_media"][0], ) self._access_media(server_and_media_id, False) @@ -340,7 +368,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + server_and_media_id.split("/")[1], + channel.json_body["deleted_media"][0], ) self._access_media(server_and_media_id, False) @@ -374,7 +403,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + server_and_media_id.split("/")[1], + channel.json_body["deleted_media"][0], ) self._access_media(server_and_media_id, False) @@ -417,7 +447,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + server_and_media_id.split("/")[1], + channel.json_body["deleted_media"][0], ) self._access_media(server_and_media_id, False) @@ -461,7 +492,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) self.assertEqual( - server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + server_and_media_id.split("/")[1], + channel.json_body["deleted_media"][0], ) self._access_media(server_and_media_id, False) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index a0f32c5512..b55160b70a 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -127,8 +127,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): self._assert_peek(room_id, expect_code=403) def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. - """ + """Assert that the admin user can (or cannot) peek into the room.""" url = "rooms/%s/initialSync" % (room_id,) channel = self.make_request( @@ -186,7 +185,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( - "POST", self.url, json.dumps({}), access_token=self.other_user_tok, + "POST", + self.url, + json.dumps({}), + access_token=self.other_user_tok, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -199,7 +201,10 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/rooms/!unknown:test/delete" channel = self.make_request( - "POST", url, json.dumps({}), access_token=self.admin_user_tok, + "POST", + url, + json.dumps({}), + access_token=self.admin_user_tok, ) self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) @@ -212,12 +217,16 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/rooms/invalidroom/delete" channel = self.make_request( - "POST", url, json.dumps({}), access_token=self.admin_user_tok, + "POST", + url, + json.dumps({}), + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - "invalidroom is not a legal room ID", channel.json_body["error"], + "invalidroom is not a legal room ID", + channel.json_body["error"], ) def test_new_room_user_does_not_exist(self): @@ -254,7 +263,8 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual( - "User must be our own: @not:exist.bla", channel.json_body["error"], + "User must be our own: @not:exist.bla", + channel.json_body["error"], ) def test_block_is_not_bool(self): @@ -491,8 +501,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._assert_peek(self.room_id, expect_code=403) def _is_blocked(self, room_id, expect=True): - """Assert that the room is blocked or not - """ + """Assert that the room is blocked or not""" d = self.store.is_room_blocked(room_id) if expect: self.assertTrue(self.get_success(d)) @@ -500,20 +509,17 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self.assertIsNone(self.get_success(d)) def _has_no_members(self, room_id): - """Assert there is now no longer anyone in the room - """ + """Assert there is now no longer anyone in the room""" users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([], users_in_room) def _is_member(self, room_id, user_id): - """Test that user is member of the room - """ + """Test that user is member of the room""" users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertIn(user_id, users_in_room) def _is_purged(self, room_id): - """Test that the following tables have been purged of all rows related to the room. - """ + """Test that the following tables have been purged of all rows related to the room.""" for table in PURGE_TABLES: count = self.get_success( self.store.db_pool.simple_select_one_onecol( @@ -527,8 +533,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room. - """ + """Assert that the admin user can (or cannot) peek into the room.""" url = "rooms/%s/initialSync" % (room_id,) channel = self.make_request( @@ -548,8 +553,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): class PurgeRoomTestCase(unittest.HomeserverTestCase): - """Test /purge_room admin API. - """ + """Test /purge_room admin API.""" servlets = [ synapse.rest.admin.register_servlets, @@ -594,8 +598,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): class RoomTestCase(unittest.HomeserverTestCase): - """Test /room admin API. - """ + """Test /room admin API.""" servlets = [ synapse.rest.admin.register_servlets, @@ -623,7 +626,9 @@ class RoomTestCase(unittest.HomeserverTestCase): # Request the list of rooms url = "/_synapse/admin/v1/rooms" channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) # Check request completed successfully @@ -685,7 +690,10 @@ class RoomTestCase(unittest.HomeserverTestCase): # Set the name of the rooms so we get a consistent returned ordering for idx, room_id in enumerate(room_ids): self.helper.send_state( - room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok, + room_id, + "m.room.name", + {"name": str(idx)}, + tok=self.admin_user_tok, ) # Request the list of rooms @@ -704,7 +712,9 @@ class RoomTestCase(unittest.HomeserverTestCase): "name", ) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual( 200, int(channel.result["code"]), msg=channel.result["body"] @@ -744,7 +754,9 @@ class RoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -788,13 +800,18 @@ class RoomTestCase(unittest.HomeserverTestCase): # Set a name for the room self.helper.send_state( - room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok, + room_id, + "m.room.name", + {"name": test_room_name}, + tok=self.admin_user_tok, ) # Request the list of rooms url = "/_synapse/admin/v1/rooms" channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -860,7 +877,9 @@ class RoomTestCase(unittest.HomeserverTestCase): ) def _order_test( - order_type: str, expected_room_list: List[str], reverse: bool = False, + order_type: str, + expected_room_list: List[str], + reverse: bool = False, ): """Request the list of rooms in a certain order. Assert that order is what we expect @@ -875,7 +894,9 @@ class RoomTestCase(unittest.HomeserverTestCase): if reverse: url += "&dir=b" channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -907,13 +928,22 @@ class RoomTestCase(unittest.HomeserverTestCase): # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C self.helper.send_state( - room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok, + room_id_1, + "m.room.name", + {"name": "A"}, + tok=self.admin_user_tok, ) self.helper.send_state( - room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok, + room_id_2, + "m.room.name", + {"name": "B"}, + tok=self.admin_user_tok, ) self.helper.send_state( - room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok, + room_id_3, + "m.room.name", + {"name": "C"}, + tok=self.admin_user_tok, ) # Set room canonical room aliases @@ -990,10 +1020,16 @@ class RoomTestCase(unittest.HomeserverTestCase): # Set the name for each room self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + room_id_1, + "m.room.name", + {"name": room_name_1}, + tok=self.admin_user_tok, ) self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + room_id_2, + "m.room.name", + {"name": room_name_2}, + tok=self.admin_user_tok, ) def _search_test( @@ -1011,7 +1047,9 @@ class RoomTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) @@ -1071,15 +1109,23 @@ class RoomTestCase(unittest.HomeserverTestCase): # Set the name for each room self.helper.send_state( - room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + room_id_1, + "m.room.name", + {"name": room_name_1}, + tok=self.admin_user_tok, ) self.helper.send_state( - room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + room_id_2, + "m.room.name", + {"name": room_name_2}, + tok=self.admin_user_tok, ) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1109,7 +1155,9 @@ class RoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["joined_local_devices"]) @@ -1121,7 +1169,9 @@ class RoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(2, channel.json_body["joined_local_devices"]) @@ -1131,7 +1181,9 @@ class RoomTestCase(unittest.HomeserverTestCase): self.helper.leave(room_id_1, user_1, tok=user_tok_1) url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["joined_local_devices"]) @@ -1160,7 +1212,9 @@ class RoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1171,7 +1225,9 @@ class RoomTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -1180,6 +1236,23 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["total"], 3) + def test_room_state(self): + """Test that room state can be requested correctly""" + # Create two test rooms + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,) + channel = self.make_request( + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("state", channel.json_body) + # testing that the state events match is painful and not done here. We assume that + # the create_room already does the right thing, so no need to verify that we got + # the state events it created. + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): @@ -1327,7 +1400,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if user is a member of the room channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + "GET", + "/_matrix/client/r0/joined_rooms", + access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) @@ -1374,7 +1449,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if server admin is a member of the room channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, + "GET", + "/_matrix/client/r0/joined_rooms", + access_token=self.admin_user_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) @@ -1396,7 +1473,9 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if user is a member of the room channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + "GET", + "/_matrix/client/r0/joined_rooms", + access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) @@ -1425,11 +1504,97 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Validate if user is a member of the room channel = self.make_request( - "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, + "GET", + "/_matrix/client/r0/joined_rooms", + access_token=self.second_tok, ) self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) + def test_context_as_non_admin(self): + """ + Test that, without being admin, one cannot use the context admin API + """ + # Create a room. + user_id = self.register_user("test", "test") + user_tok = self.login("test", "test") + + self.register_user("test_2", "test") + user_tok_2 = self.login("test_2", "test") + + room_id = self.helper.create_room_as(user_id, tok=user_tok) + + # Populate the room with events. + events = [] + for i in range(30): + events.append( + self.helper.send_event( + room_id, "com.example.test", content={"index": i}, tok=user_tok + ) + ) + + # Now attempt to find the context using the admin API without being admin. + midway = (len(events) - 1) // 2 + for tok in [user_tok, user_tok_2]: + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/context/%s" + % (room_id, events[midway]["event_id"]), + access_token=tok, + ) + self.assertEquals( + 403, int(channel.result["code"]), msg=channel.result["body"] + ) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_context_as_admin(self): + """ + Test that, as admin, we can find the context of an event without having joined the room. + """ + + # Create a room. We're not part of it. + user_id = self.register_user("test", "test") + user_tok = self.login("test", "test") + room_id = self.helper.create_room_as(user_id, tok=user_tok) + + # Populate the room with events. + events = [] + for i in range(30): + events.append( + self.helper.send_event( + room_id, "com.example.test", content={"index": i}, tok=user_tok + ) + ) + + # Now let's fetch the context for this room. + midway = (len(events) - 1) // 2 + channel = self.make_request( + "GET", + "/_synapse/admin/v1/rooms/%s/context/%s" + % (room_id, events[midway]["event_id"]), + access_token=self.admin_user_tok, + ) + self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals( + channel.json_body["event"]["event_id"], events[midway]["event_id"] + ) + + for i, found_event in enumerate(channel.json_body["events_before"]): + for j, posted_event in enumerate(events): + if found_event["event_id"] == posted_event["event_id"]: + self.assertTrue(j < midway) + break + else: + self.fail("Event %s from events_before not found" % j) + + for i, found_event in enumerate(channel.json_body["events_after"]): + for j, posted_event in enumerate(events): + if found_event["event_id"] == posted_event["event_id"]: + self.assertTrue(j > midway) + break + else: + self.fail("Event %s from events_after not found" % j) + class MakeRoomAdminTestCase(unittest.HomeserverTestCase): servlets = [ @@ -1456,8 +1621,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): ) def test_public_room(self): - """Test that getting admin in a public room works. - """ + """Test that getting admin in a public room works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True ) @@ -1482,10 +1646,11 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): ) def test_private_room(self): - """Test that getting admin in a private room works and we get invited. - """ + """Test that getting admin in a private room works and we get invited.""" room_id = self.helper.create_room_as( - self.creator, tok=self.creator_tok, is_public=False, + self.creator, + tok=self.creator_tok, + is_public=False, ) channel = self.make_request( @@ -1509,8 +1674,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): ) def test_other_user(self): - """Test that giving admin in a public room works to a non-admin user works. - """ + """Test that giving admin in a public room works to a non-admin user works.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True ) @@ -1535,8 +1699,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): ) def test_not_enough_power(self): - """Test that we get a sensible error if there are no local room admins. - """ + """Test that we get a sensible error if there are no local room admins.""" room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True ) diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index f48be3d65a..1f1d11f527 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -55,7 +55,10 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): If the user is not a server admin, an error 403 is returned. """ channel = self.make_request( - "GET", self.url, json.dumps({}), access_token=self.other_user_tok, + "GET", + self.url, + json.dumps({}), + access_token=self.other_user_tok, ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -67,7 +70,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): """ # unkown order_by channel = self.make_request( - "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok, + "GET", + self.url + "?order_by=bar", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -75,7 +80,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # negative from channel = self.make_request( - "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -83,7 +90,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # negative limit channel = self.make_request( - "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -91,7 +100,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # negative from_ts channel = self.make_request( - "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok, + "GET", + self.url + "?from_ts=-1234", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -99,7 +110,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # negative until_ts channel = self.make_request( - "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok, + "GET", + self.url + "?until_ts=-1234", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -117,7 +130,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # empty search term channel = self.make_request( - "GET", self.url + "?search_term=", access_token=self.admin_user_tok, + "GET", + self.url + "?search_term=", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -125,7 +140,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # invalid search order channel = self.make_request( - "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -138,7 +155,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self._create_users_with_media(10, 2) channel = self.make_request( - "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -154,7 +173,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self._create_users_with_media(20, 2) channel = self.make_request( - "GET", self.url + "?from=5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -170,7 +191,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self._create_users_with_media(20, 2) channel = self.make_request( - "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -190,7 +213,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of results is the number of entries channel = self.make_request( - "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=20", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -201,7 +226,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of max results is larger than the number of entries channel = self.make_request( - "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=21", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -212,7 +239,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # `next_token` does appear # Number of max results is smaller than the number of entries channel = self.make_request( - "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -223,7 +252,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # Set `from` to value of `next_token` for request remaining entries # Check `next_token` does not appear channel = self.make_request( - "GET", self.url + "?from=19", access_token=self.admin_user_tok, + "GET", + self.url + "?from=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -237,7 +268,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): if users have no media created """ - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -264,10 +299,14 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # order by user_id self._order_test("user_id", ["@user_a:test", "@user_b:test", "@user_c:test"]) self._order_test( - "user_id", ["@user_a:test", "@user_b:test", "@user_c:test"], "f", + "user_id", + ["@user_a:test", "@user_b:test", "@user_c:test"], + "f", ) self._order_test( - "user_id", ["@user_c:test", "@user_b:test", "@user_a:test"], "b", + "user_id", + ["@user_c:test", "@user_b:test", "@user_a:test"], + "b", ) # order by displayname @@ -275,32 +314,46 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"] ) self._order_test( - "displayname", ["@user_c:test", "@user_b:test", "@user_a:test"], "f", + "displayname", + ["@user_c:test", "@user_b:test", "@user_a:test"], + "f", ) self._order_test( - "displayname", ["@user_a:test", "@user_b:test", "@user_c:test"], "b", + "displayname", + ["@user_a:test", "@user_b:test", "@user_c:test"], + "b", ) # order by media_length self._order_test( - "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"], + "media_length", + ["@user_a:test", "@user_c:test", "@user_b:test"], ) self._order_test( - "media_length", ["@user_a:test", "@user_c:test", "@user_b:test"], "f", + "media_length", + ["@user_a:test", "@user_c:test", "@user_b:test"], + "f", ) self._order_test( - "media_length", ["@user_b:test", "@user_c:test", "@user_a:test"], "b", + "media_length", + ["@user_b:test", "@user_c:test", "@user_a:test"], + "b", ) # order by media_count self._order_test( - "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"], + "media_count", + ["@user_a:test", "@user_c:test", "@user_b:test"], ) self._order_test( - "media_count", ["@user_a:test", "@user_c:test", "@user_b:test"], "f", + "media_count", + ["@user_a:test", "@user_c:test", "@user_b:test"], + "f", ) self._order_test( - "media_count", ["@user_b:test", "@user_c:test", "@user_a:test"], "b", + "media_count", + ["@user_b:test", "@user_c:test", "@user_a:test"], + "b", ) def test_from_until_ts(self): @@ -313,14 +366,20 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ts1 = self.clock.time_msec() # list all media when filter is not set - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 3) # filter media starting at `ts1` after creating first media # result is 0 channel = self.make_request( - "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok, + "GET", + self.url + "?from_ts=%s" % (ts1,), + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 0) @@ -342,7 +401,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # filter media until `ts2` and earlier channel = self.make_request( - "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok, + "GET", + self.url + "?until_ts=%s" % (ts2,), + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["users"][0]["media_count"], 6) @@ -351,7 +412,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self._create_users_with_media(20, 1) # check without filter get all users - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 20) @@ -376,7 +441,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): # filter and get empty result channel = self.make_request( - "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok, + "GET", + self.url + "?search_term=foobar", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(channel.json_body["total"], 0) @@ -441,7 +508,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): if dir is not None and dir in ("b", "f"): url += "&dir=%s" % (dir,) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], len(expected_user_list)) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 04599c2fcf..ba26895391 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync +from synapse.types import JsonDict from tests import unittest from tests.test_utils import make_awaitable @@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.user1 = self.register_user( - "user1", "pass1", admin=False, displayname="Name 1" - ) - self.user2 = self.register_user( - "user2", "pass2", admin=False, displayname="Name 2" - ) - def test_no_auth(self): """ Try to list users without authentication. @@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error is returned. """ + self._create_users(1) other_user_token = self.login("user1", "pass1") channel = self.make_request("GET", self.url, access_token=other_user_token) @@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ List all users, including deactivated users. """ + self._create_users(2) + channel = self.make_request( "GET", self.url + "?deactivated=true", @@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(3, channel.json_body["total"]) # Check that all fields are available - for u in channel.json_body["users"]: - self.assertIn("name", u) - self.assertIn("is_guest", u) - self.assertIn("admin", u) - self.assertIn("user_type", u) - self.assertIn("deactivated", u) - self.assertIn("displayname", u) - self.assertIn("avatar_url", u) + self._check_fields(channel.json_body["users"]) def test_search_term(self): """Test that searching for a users works correctly""" @@ -538,9 +528,14 @@ class UsersListTestCase(unittest.HomeserverTestCase): search_field: Field which is to request: `name` or `user_id` expected_http_code: The expected http code for the request """ - url = self.url + "?%s=%s" % (search_field, search_term,) + url = self.url + "?%s=%s" % ( + search_field, + search_term, + ) channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok, + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) @@ -549,6 +544,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): # Check that users were returned self.assertTrue("users" in channel.json_body) + self._check_fields(channel.json_body["users"]) users = channel.json_body["users"] # Check that the expected number of users were returned @@ -561,25 +557,30 @@ class UsersListTestCase(unittest.HomeserverTestCase): u = users[0] self.assertEqual(expected_user_id, u["name"]) + self._create_users(2) + + user1 = "@user1:test" + user2 = "@user2:test" + # Perform search tests - _search_test(self.user1, "er1") - _search_test(self.user1, "me 1") + _search_test(user1, "er1") + _search_test(user1, "me 1") - _search_test(self.user2, "er2") - _search_test(self.user2, "me 2") + _search_test(user2, "er2") + _search_test(user2, "me 2") - _search_test(self.user1, "er1", "user_id") - _search_test(self.user2, "er2", "user_id") + _search_test(user1, "er1", "user_id") + _search_test(user2, "er2", "user_id") # Test case insensitive - _search_test(self.user1, "ER1") - _search_test(self.user1, "NAME 1") + _search_test(user1, "ER1") + _search_test(user1, "NAME 1") - _search_test(self.user2, "ER2") - _search_test(self.user2, "NAME 2") + _search_test(user2, "ER2") + _search_test(user2, "NAME 2") - _search_test(self.user1, "ER1", "user_id") - _search_test(self.user2, "ER2", "user_id") + _search_test(user1, "ER1", "user_id") + _search_test(user2, "ER2", "user_id") _search_test(None, "foo") _search_test(None, "bar") @@ -587,6 +588,205 @@ class UsersListTestCase(unittest.HomeserverTestCase): _search_test(None, "foo", "user_id") _search_test(None, "bar", "user_id") + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + + # negative limit + channel = self.make_request( + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # negative from + channel = self.make_request( + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # invalid guests + channel = self.make_request( + "GET", + self.url + "?guests=not_bool", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid deactivated + channel = self.make_request( + "GET", + self.url + "?deactivated=not_bool", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + def test_limit(self): + """ + Testing list of users with limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 5) + self.assertEqual(channel.json_body["next_token"], "5") + self._check_fields(channel.json_body["users"]) + + def test_from(self): + """ + Testing list of users with a defined starting point (from) + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["users"]) + + def test_limit_and_from(self): + """ + Testing list of users with a defined starting point and limit + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + channel = self.make_request( + "GET", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(channel.json_body["next_token"], "15") + self.assertEqual(len(channel.json_body["users"]), 10) + self._check_fields(channel.json_body["users"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_users = 20 + # Create one less user (since there's already an admin user). + self._create_users(number_users - 1) + + # `next_token` does not appear + # Number of results is the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=20", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=21", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), number_users) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + channel = self.make_request( + "GET", + self.url + "?limit=19", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 19) + self.assertEqual(channel.json_body["next_token"], "19") + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + channel = self.make_request( + "GET", + self.url + "?from=19", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_users) + self.assertEqual(len(channel.json_body["users"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def _check_fields(self, content: JsonDict): + """Checks that the expected user attributes are present in content + Args: + content: List that is checked for content + """ + for u in content: + self.assertIn("name", u) + self.assertIn("is_guest", u) + self.assertIn("admin", u) + self.assertIn("user_type", u) + self.assertIn("deactivated", u) + self.assertIn("shadow_banned", u) + self.assertIn("displayname", u) + self.assertIn("avatar_url", u) + + def _create_users(self, number_users: int): + """ + Create a number of users + Args: + number_users: Number of users to be created + """ + for i in range(1, number_users + 1): + self.register_user( + "user%d" % i, + "pass%d" % i, + admin=False, + displayname="Name %d" % i, + ) + class DeactivateAccountTestCase(unittest.HomeserverTestCase): @@ -639,7 +839,10 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( - "POST", url, access_token=self.other_user_token, content=b"{}", + "POST", + url, + access_token=self.other_user_token, + content=b"{}", ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -693,7 +896,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -717,7 +922,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -736,7 +943,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -760,7 +969,9 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -773,8 +984,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): self._is_erased("@user:test", False) def _is_erased(self, user_id: str, expect: bool) -> None: - """Assert that the user is erased or not - """ + """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) if expect: self.assertTrue(self.get_success(d)) @@ -808,13 +1018,20 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v2/users/@bob:test" - channel = self.make_request("GET", url, access_token=self.other_user_token,) + channel = self.make_request( + "GET", + url, + access_token=self.other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("You are not a server admin", channel.json_body["error"]) channel = self.make_request( - "PUT", url, access_token=self.other_user_token, content=b"{}", + "PUT", + url, + access_token=self.other_user_token, + content=b"{}", ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) @@ -867,7 +1084,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -912,7 +1133,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -922,6 +1147,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(False, channel.json_body["admin"]) self.assertEqual(False, channel.json_body["is_guest"]) self.assertEqual(False, channel.json_body["deactivated"]) + self.assertEqual(False, channel.json_body["shadow_banned"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) @override_config( @@ -1137,7 +1363,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1168,7 +1396,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1191,7 +1421,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1221,7 +1453,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1319,7 +1553,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1348,7 +1584,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_other_user, access_token=self.admin_user_tok, + "GET", + self.url_other_user, + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1377,7 +1615,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("bob", channel.json_body["displayname"]) # Get user - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1397,7 +1639,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) # Check user is not deactivated - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) @@ -1407,8 +1653,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["deactivated"]) def _is_erased(self, user_id, expect): - """Assert that the user is erased or not - """ + """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) if expect: self.assertTrue(self.get_success(d)) @@ -1448,7 +1693,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - channel = self.make_request("GET", self.url, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url, + access_token=other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1458,7 +1707,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns an empty list """ url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1470,7 +1723,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1482,7 +1739,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): if user has no memberships """ # Get rooms - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1499,7 +1760,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.helper.create_room_as(self.other_user, tok=other_user_tok) # Get rooms - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) @@ -1542,7 +1807,11 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): # Now get rooms url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) @@ -1582,7 +1851,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - channel = self.make_request("GET", self.url, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url, + access_token=other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1592,7 +1865,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1603,7 +1880,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1614,7 +1895,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ # Get pushers - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1641,7 +1926,11 @@ class PushersRestTestCase(unittest.HomeserverTestCase): ) # Get pushers - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(1, channel.json_body["total"]) @@ -1690,7 +1979,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - channel = self.make_request("GET", self.url, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url, + access_token=other_user_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -1700,7 +1993,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/media" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) @@ -1711,7 +2008,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" - channel = self.make_request("GET", url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) @@ -1726,7 +2027,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self._create_media(other_user_tok, number_media) channel = self.make_request( - "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1745,7 +2048,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self._create_media(other_user_tok, number_media) channel = self.make_request( - "GET", self.url + "?from=5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1764,7 +2069,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self._create_media(other_user_tok, number_media) channel = self.make_request( - "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + "GET", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1779,7 +2086,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( - "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -1791,7 +2100,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( - "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + "GET", + self.url + "?from=-5", + access_token=self.admin_user_tok, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -1809,7 +2120,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of results is the number of entries channel = self.make_request( - "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=20", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1820,7 +2133,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # `next_token` does not appear # Number of max results is larger than the number of entries channel = self.make_request( - "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=21", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1831,7 +2146,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # `next_token` does appear # Number of max results is smaller than the number of entries channel = self.make_request( - "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + "GET", + self.url + "?limit=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1843,7 +2160,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # Set `from` to value of `next_token` for request remaining entries # `next_token` does not appear channel = self.make_request( - "GET", self.url + "?from=19", access_token=self.admin_user_tok, + "GET", + self.url + "?from=19", + access_token=self.admin_user_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -1857,7 +2176,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): if user has no media created """ - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["total"]) @@ -1872,7 +2195,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") self._create_media(other_user_tok, number_media) - channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_media, channel.json_body["total"]) @@ -1899,8 +2226,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ) def _check_fields(self, content): - """Checks that all attributes are present in content - """ + """Checks that all attributes are present in content""" for m in content: self.assertIn("media_id", m) self.assertIn("media_type", m) @@ -1913,8 +2239,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): class UserTokenRestTestCase(unittest.HomeserverTestCase): - """Test for /_synapse/admin/v1/users/<user>/login - """ + """Test for /_synapse/admin/v1/users/<user>/login""" servlets = [ synapse.rest.admin.register_servlets, @@ -1945,16 +2270,14 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): return channel.json_body["access_token"] def test_no_auth(self): - """Try to login as a user without authentication. - """ + """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_not_admin(self): - """Try to login as a user as a non-admin user. - """ + """Try to login as a user as a non-admin user.""" channel = self.make_request( "POST", self.url, b"{}", access_token=self.other_user_tok ) @@ -1962,8 +2285,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) def test_send_event(self): - """Test that sending event as a user works. - """ + """Test that sending event as a user works.""" # Create a room. room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok) @@ -1977,8 +2299,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): self.assertEqual(event.sender, self.other_user) def test_devices(self): - """Tests that logging in as a user doesn't create a new device for them. - """ + """Tests that logging in as a user doesn't create a new device for them.""" # Login in as the user self._get_token() @@ -1992,8 +2313,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["devices"]), 1) def test_logout(self): - """Test that calling `/logout` with the token works. - """ + """Test that calling `/logout` with the token works.""" # Login in as the user puppet_token = self._get_token() @@ -2083,8 +2403,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): } ) def test_consent(self): - """Test that sending a message is not subject to the privacy policies. - """ + """Test that sending a message is not subject to the privacy policies.""" # Have the admin user accept the terms. self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0")) @@ -2159,11 +2478,19 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.register_user("user2", "pass") other_user2_token = self.login("user2", "pass") - channel = self.make_request("GET", self.url1, access_token=other_user2_token,) + channel = self.make_request( + "GET", + self.url1, + access_token=other_user2_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - channel = self.make_request("GET", self.url2, access_token=other_user2_token,) + channel = self.make_request( + "GET", + self.url2, + access_token=other_user2_token, + ) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) @@ -2174,11 +2501,19 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain" - channel = self.make_request("GET", url1, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url1, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) - channel = self.make_request("GET", url2, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + url2, + access_token=self.admin_user_tok, + ) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only whois a local user", channel.json_body["error"]) @@ -2186,12 +2521,20 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): """ The lookup should succeed for an admin. """ - channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url1, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) - channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,) + channel = self.make_request( + "GET", + self.url2, + access_token=self.admin_user_tok, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) @@ -2202,12 +2545,84 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): """ other_user_token = self.login("user", "pass") - channel = self.make_request("GET", self.url1, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url1, + access_token=other_user_token, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) - channel = self.make_request("GET", self.url2, access_token=other_user_token,) + channel = self.make_request( + "GET", + self.url2, + access_token=other_user_token, + ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) + + +class ShadowBanRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + + self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to get information of an user without authentication. + """ + channel = self.make_request("POST", self.url) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + channel = self.make_request("POST", self.url, access_token=other_user_token) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that shadow-banning for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" + + channel = self.make_request("POST", url, access_token=self.admin_user_tok) + self.assertEqual(400, channel.code, msg=channel.json_body) + + def test_success(self): + """ + Shadow-banning should succeed for an admin. + """ + # The user starts off as not shadow-banned. + other_user_token = self.login("user", "pass") + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertFalse(result.shadow_banned) + + channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + # Ensure the user is shadow-banned (and the cache was cleared). + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertTrue(result.shadow_banned) diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index 913ea3c98e..5256c11fe6 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -73,7 +73,9 @@ class PowerLevelsTestCase(HomeserverTestCase): # Mod the mod room_power_levels = self.helper.get_state( - self.room_id, "m.room.power_levels", tok=self.admin_access_token, + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, ) # Update existing power levels with mod at PL50 diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index f0707646bb..e0c74591b6 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -181,8 +181,7 @@ class RedactionsTestCase(HomeserverTestCase): ) def test_redact_event_as_moderator_ratelimit(self): - """Tests that the correct ratelimiting is applied to redactions - """ + """Tests that the correct ratelimiting is applied to redactions""" message_ids = [] # as a regular user, send messages to redact diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 31dc832fd5..aee99bb6a0 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -250,7 +250,8 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["backfill"]) self.hs = self.setup_test_homeserver( - config=config, federation_client=mock_federation_client, + config=config, + federation_client=mock_federation_client, ) return self.hs diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index e689c3fbea..d2cce44032 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -18,6 +18,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.types import UserID from tests import unittest @@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase): self.store = self.hs.get_datastore() self.get_success( - self.store.db_pool.simple_update( - table="users", - keyvalues={"name": self.banned_user_id}, - updatevalues={"shadow_banned": True}, - desc="shadow_ban", - ) + self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True) ) self.other_user_id = self.register_user("otheruser", "pass") @@ -264,7 +260,10 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + self.banned_user_id, + room_id, + "m.room.member", + self.banned_user_id, ) ) self.assertEqual( @@ -296,7 +295,10 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + self.banned_user_id, + room_id, + "m.room.member", + self.banned_user_id, ) ) self.assertEqual( diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 0a5ca317ea..2ae896db1e 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -150,6 +150,8 @@ class GetEventsTestCase(unittest.HomeserverTestCase): event_id = resp["event_id"] channel = self.make_request( - "GET", "/events/" + event_id, access_token=self.token, + "GET", + "/events/" + event_id, + access_token=self.token, ) self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2672ce24c6..fb29eaed6f 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,7 +15,7 @@ import time import urllib.parse -from typing import Any, Dict, Union +from typing import Any, Dict, List, Union from urllib.parse import urlencode from mock import Mock @@ -29,8 +29,7 @@ from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet -from synapse.rest.synapse.client.pick_idp import PickIdpResource -from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import create_requester from tests import unittest @@ -75,6 +74,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"' # the query params in TEST_CLIENT_REDIRECT_URL EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')] +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -419,13 +422,61 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): return config def create_resource_dict(self) -> Dict[str, Resource]: - from synapse.rest.oidc import OIDCResource - d = super().create_resource_dict() - d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) - d["/_synapse/oidc"] = OIDCResource(self.hs) + d.update(build_synapse_client_resource_tree(self.hs)) return d + def test_get_login_flows(self): + """GET /login should return password and SSO flows""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.sso"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(channel.json_body["flows"], expected_flows) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_get_msc2858_login_flows(self): + """The SSO flow should include IdP info if MSC2858 is enabled""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + # stick the flows results in a dict by type + flow_results = {} # type: Dict[str, Any] + for f in channel.json_body["flows"]: + flow_type = f["type"] + self.assertNotIn( + flow_type, flow_results, "duplicate flow type %s" % (flow_type,) + ) + flow_results[flow_type] = f + + self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") + sso_flow = flow_results.pop("m.login.sso") + # we should have a set of IdPs + self.assertCountEqual( + sso_flow["org.matrix.msc2858.identity_providers"], + [ + {"id": "cas", "name": "CAS"}, + {"id": "saml", "name": "SAML"}, + {"id": "oidc-idp1", "name": "IDP1"}, + {"id": "oidc", "name": "OIDC"}, + ], + ) + + # the rest of the flows are simple + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(flow_results.values(), expected_flows) + def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker @@ -442,13 +493,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # parse the form to check it has fields assumed elsewhere in this class + html = channel.result["body"].decode("utf-8") p = TestHtmlParser() - p.feed(channel.result["body"].decode("utf-8")) + p.feed(html) p.close() - self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"]) + # there should be a link for each href + returned_idps = [] # type: List[str] + for link in p.links: + path, query = link.split("?", 1) + self.assertEqual(path, "pick_idp") + params = urllib.parse.parse_qs(query) + self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL]) + returned_idps.append(params["idp"][0]) - self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) + self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) def test_multi_sso_redirect_to_cas(self): """If CAS is chosen, should redirect to the CAS server""" @@ -552,7 +611,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # matrix access token, mxid, and device id. login_token = params[2][1] chan = self.make_request( - "POST", "/login", content={"type": "m.login.token", "token": login_token}, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, ) self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") @@ -560,9 +621,47 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_multi_sso_redirect_to_unknown(self): """An unknown IdP should cause a 400""" channel = self.make_request( - "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", + "GET", + "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), ) self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_unknown(self): + """If the client tries to pick an unknown IdP, return a 404""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_oidc(self): + """If the client pick a known IdP, redirect to it""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) @staticmethod def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: @@ -584,10 +683,12 @@ class CASTestCase(unittest.HomeserverTestCase): self.redirect_path = "_synapse/client/login/sso/redirect/confirm" config = self.default_config() + config["public_baseurl"] = ( + config.get("public_baseurl") or "https://matrix.goodserver.com:8448" + ) config["cas_config"] = { "enabled": True, "server_url": CAS_SERVER, - "service_url": "https://matrix.goodserver.com:8448", } cas_user_id = "username" @@ -621,7 +722,8 @@ class CASTestCase(unittest.HomeserverTestCase): mocked_http_client.get_raw.side_effect = get_raw self.hs = self.setup_test_homeserver( - config=config, proxied_http_client=mocked_http_client, + config=config, + proxied_http_client=mocked_http_client, ) return self.hs @@ -1119,11 +1221,8 @@ class UsernamePickerTestCase(HomeserverTestCase): return config def create_resource_dict(self) -> Dict[str, Resource]: - from synapse.rest.oidc import OIDCResource - d = super().create_resource_dict() - d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) - d["/_synapse/oidc"] = OIDCResource(self.hs) + d.update(build_synapse_client_resource_tree(self.hs)) return d def test_username_picker(self): @@ -1137,7 +1236,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # that should redirect to the username picker self.assertEqual(channel.code, 302, channel.result) picker_url = channel.headers.getRawHeaders("Location")[0] - self.assertEqual(picker_url, "/_synapse/client/pick_username") + self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") # ... with a username_mapping_session cookie cookies = {} # type: Dict[str,str] @@ -1149,7 +1248,9 @@ class UsernamePickerTestCase(HomeserverTestCase): # looks ok. username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions self.assertIn( - session_id, username_mapping_sessions, "session id not found in map", + session_id, + username_mapping_sessions, + "session id not found in map", ) session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") @@ -1161,12 +1262,11 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) # Now, submit a username to the username picker, which should serve a redirect - # back to the client - submit_path = picker_url + "/submit" + # to the completion page content = urlencode({b"username": b"bobby"}).encode("utf8") chan = self.make_request( "POST", - path=submit_path, + path=picker_url, content=content, content_is_form=True, custom_headers=[ @@ -1178,6 +1278,16 @@ class UsernamePickerTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") + + # send a request to the completion page, which should 302 to the client redirectUrl + chan = self.make_request( + "GET", + path=location_headers[0], + custom_headers=[("Cookie", "username_mapping_session=" + session_id)], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location matches the requested redirect URL path, query = location_headers[0].split("?", 1) self.assertEqual(path, "https://x") @@ -1195,7 +1305,9 @@ class UsernamePickerTestCase(HomeserverTestCase): # finally, submit the matrix login token to the login API, which gives us our # matrix access token, mxid, and device id. chan = self.make_request( - "POST", "/login", content={"type": "m.login.token", "token": login_token}, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, ) self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index e59fa70baa..f3448c94dd 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,163 +14,11 @@ # limitations under the License. """Tests REST events for /profile paths.""" -import json - -from mock import Mock - -from twisted.internet import defer - -import synapse.types -from synapse.api.errors import AuthError, SynapseError from synapse.rest import admin from synapse.rest.client.v1 import login, profile, room from tests import unittest -from ....utils import MockHttpResource, setup_test_homeserver - -myid = "@1234ABCD:test" -PATH_PREFIX = "/_matrix/client/r0" - - -class MockHandlerProfileTestCase(unittest.TestCase): - """ Tests rest layer of profile management. - - Todo: move these into ProfileTestCase - """ - - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) - self.mock_handler = Mock( - spec=[ - "get_displayname", - "set_displayname", - "get_avatar_url", - "set_avatar_url", - "check_profile_query_allowed", - ] - ) - - self.mock_handler.get_displayname.return_value = defer.succeed(Mock()) - self.mock_handler.set_displayname.return_value = defer.succeed(Mock()) - self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock()) - self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock()) - self.mock_handler.check_profile_query_allowed.return_value = defer.succeed( - Mock() - ) - - hs = yield setup_test_homeserver( - self.addCleanup, - "test", - federation_http_client=None, - resource_for_client=self.mock_resource, - federation=Mock(), - federation_client=Mock(), - profile_handler=self.mock_handler, - ) - - async def _get_user_by_req(request=None, allow_guest=False): - return synapse.types.create_requester(myid) - - hs.get_auth().get_user_by_req = _get_user_by_req - - profile.register_servlets(hs, self.mock_resource) - - @defer.inlineCallbacks - def test_get_my_name(self): - mocked_get = self.mock_handler.get_displayname - mocked_get.return_value = defer.succeed("Frank") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/displayname" % (myid), None - ) - - self.assertEquals(200, code) - self.assertEquals({"displayname": "Frank"}, response) - self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD") - - @defer.inlineCallbacks - def test_set_my_name(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.return_value = defer.succeed(()) - - (code, response) = yield self.mock_resource.trigger( - "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}' - ) - - self.assertEquals(200, code) - self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.") - - @defer.inlineCallbacks - def test_set_my_name_noauth(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.side_effect = AuthError(400, "message") - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/displayname" % ("@4567:test"), - b'{"displayname": "Frank Jr."}', - ) - - self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code)) - - @defer.inlineCallbacks - def test_get_other_name(self): - mocked_get = self.mock_handler.get_displayname - mocked_get.return_value = defer.succeed("Bob") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/displayname" % ("@opaque:elsewhere"), None - ) - - self.assertEquals(200, code) - self.assertEquals({"displayname": "Bob"}, response) - - @defer.inlineCallbacks - def test_set_other_name(self): - mocked_set = self.mock_handler.set_displayname - mocked_set.side_effect = SynapseError(400, "message") - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/displayname" % ("@opaque:elsewhere"), - b'{"displayname":"bob"}', - ) - - self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code)) - - @defer.inlineCallbacks - def test_get_my_avatar(self): - mocked_get = self.mock_handler.get_avatar_url - mocked_get.return_value = defer.succeed("http://my.server/me.png") - - (code, response) = yield self.mock_resource.trigger( - "GET", "/profile/%s/avatar_url" % (myid), None - ) - - self.assertEquals(200, code) - self.assertEquals({"avatar_url": "http://my.server/me.png"}, response) - self.assertEquals(mocked_get.call_args[0][0].localpart, "1234ABCD") - - @defer.inlineCallbacks - def test_set_my_avatar(self): - mocked_set = self.mock_handler.set_avatar_url - mocked_set.return_value = defer.succeed(()) - - (code, response) = yield self.mock_resource.trigger( - "PUT", - "/profile/%s/avatar_url" % (myid), - b'{"avatar_url": "http://my.server/pic.gif"}', - ) - - self.assertEquals(200, code) - self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD") - self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") - class ProfileTestCase(unittest.HomeserverTestCase): @@ -187,37 +35,122 @@ class ProfileTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.owner = self.register_user("owner", "pass") self.owner_tok = self.login("owner", "pass") + self.other = self.register_user("other", "pass", displayname="Bob") + + def test_get_displayname(self): + res = self._get_displayname() + self.assertEqual(res, "owner") def test_set_displayname(self): channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), - content=json.dumps({"displayname": "test"}), + content={"displayname": "test"}, access_token=self.owner_tok, ) self.assertEqual(channel.code, 200, channel.result) - res = self.get_displayname() + res = self._get_displayname() self.assertEqual(res, "test") + def test_set_displayname_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test"}, + ) + self.assertEqual(channel.code, 401, channel.result) + def test_set_displayname_too_long(self): """Attempts to set a stupid displayname should get a 400""" channel = self.make_request( "PUT", "/profile/%s/displayname" % (self.owner,), - content=json.dumps({"displayname": "test" * 100}), + content={"displayname": "test" * 100}, access_token=self.owner_tok, ) self.assertEqual(channel.code, 400, channel.result) - res = self.get_displayname() + res = self._get_displayname() self.assertEqual(res, "owner") - def get_displayname(self): - channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,)) + def test_get_displayname_other(self): + res = self._get_displayname(self.other) + self.assertEquals(res, "Bob") + + def test_set_displayname_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.other,), + content={"displayname": "test"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_get_avatar_url(self): + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_set_avatar_url(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res = self._get_avatar_url() + self.assertEqual(res, "http://my.server/pic.gif") + + def test_set_avatar_url_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + ) + self.assertEqual(channel.code, 401, channel.result) + + def test_set_avatar_url_too_long(self): + """Attempts to set a stupid avatar_url should get a 400""" + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif" * 100}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_get_avatar_url_other(self): + res = self._get_avatar_url(self.other) + self.assertIsNone(res) + + def test_set_avatar_url_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.other,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def _get_displayname(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/displayname" % (name or self.owner,) + ) self.assertEqual(channel.code, 200, channel.result) return channel.json_body["displayname"] + def _get_avatar_url(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/avatar_url" % (name or self.owner,) + ) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body.get("avatar_url") + class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index d4e3165436..ed65f645fc 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -46,7 +46,9 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", federation_http_client=None, federation_client=Mock(), + "red", + federation_http_client=None, + federation_client=Mock(), ) self.hs.get_federation_handler = Mock() @@ -616,6 +618,41 @@ class RoomMemberStateTestCase(RoomBase): self.assertEquals(json.loads(content), channel.json_body) +class RoomInviteRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_rooms_ratelimit(self): + """Tests that invites in a room are actually rate-limited.""" + room_id = self.helper.create_room_as(self.user_id) + + for i in range(3): + self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,)) + + self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429) + + @unittest.override_config( + {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_users_ratelimit(self): + """Tests that invites to a specific user are actually rate-limited.""" + + for i in range(3): + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red") + + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) + + class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" @@ -1445,7 +1482,9 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), 2, [result["result"]["content"] for result in results], + len(results), + 2, + [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], @@ -1480,7 +1519,9 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), 4, [result["result"]["content"] for result in results], + len(results), + 4, + [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], @@ -1527,7 +1568,9 @@ class LabelsTestCase(unittest.HomeserverTestCase): results = channel.json_body["search_categories"]["room_events"]["results"] self.assertEqual( - len(results), 1, [result["result"]["content"] for result in results], + len(results), + 1, + [result["result"]["content"] for result in results], ) self.assertEqual( results[0]["result"]["content"]["body"], diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 38c51525a3..329dbd06de 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -18,8 +18,6 @@ from mock import Mock -from twisted.internet import defer - from synapse.rest.client.v1 import room from synapse.types import UserID @@ -39,7 +37,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", federation_http_client=None, federation_client=Mock(), + "red", + federation_http_client=None, + federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] @@ -60,32 +60,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_datastore().insert_client_ip = _insert_client_ip - def get_room_members(room_id): - if room_id == self.room_id: - return defer.succeed([self.user]) - else: - return defer.succeed([]) - - @defer.inlineCallbacks - def fetch_room_distributions_into( - room_id, localusers=None, remotedomains=None, ignore_user=None - ): - members = yield get_room_members(room_id) - for member in members: - if ignore_user is not None and member == ignore_user: - continue - - if hs.is_mine(member): - if localusers is not None: - localusers.add(member) - else: - if remotedomains is not None: - remotedomains.add(member.domain) - - hs.get_room_member_handler().fetch_room_distributions_into = ( - fetch_room_distributions_into - ) - return hs def prepare(self, reactor, clock, hs): diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index b1333df82d..8231a423f3 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -166,9 +166,12 @@ class RestHelper: json.dumps(data).encode("utf8"), ) - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], ) self.auth_user_id = temp_id @@ -201,9 +204,12 @@ class RestHelper: json.dumps(content).encode("utf8"), ) - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], ) return channel.json_body @@ -251,9 +257,12 @@ class RestHelper: channel = make_request(self.hs.get_reactor(), self.site, method, path, content) - assert int(channel.result["code"]) == expect_code, ( - "Expected: %d, got: %d, resp: %r" - % (expect_code, int(channel.result["code"]), channel.result["body"]) + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], ) return channel.json_body @@ -447,7 +456,10 @@ class RestHelper: return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) def complete_oidc_auth( - self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, + self, + oauth_uri: str, + cookies: Mapping[str, str], + user_info_dict: JsonDict, ) -> FakeChannel: """Mock out an OIDC authentication flow @@ -491,7 +503,9 @@ class RestHelper: (expected_uri, resp_obj) = expected_requests.pop(0) assert uri == expected_uri resp = FakeResponse( - code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"), + code=200, + phrase=b"OK", + body=json.dumps(resp_obj).encode("utf-8"), ) return resp diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index cb87b80e33..e72b61963d 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,7 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership -from synapse.api.errors import Codes +from synapse.api.errors import Codes, HttpResponseException from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource @@ -75,8 +75,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.submit_token_resource = PasswordResetSubmitTokenResource(hs) def test_basic_password_reset(self): - """Test basic password reset flow - """ + """Test basic password reset flow""" old_password = "monkey" new_password = "kangeroo" @@ -112,6 +111,55 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_email(self): + """Test that we ratelimit /requestToken for the same email.""" + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test1@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + def reset(ip): + client_secret = "foobar" + session_id = self._request_token(email, client_secret, ip) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + self.email_attempts.clear() + + # We expect to be able to make three requests before getting rate + # limited. + # + # We change IPs to ensure that we're not being ratelimited due to the + # same IP + reset("127.0.0.1") + reset("127.0.0.2") + reset("127.0.0.3") + + with self.assertRaises(HttpResponseException) as cm: + reset("127.0.0.4") + + self.assertEqual(cm.exception.code, 429) + def test_basic_password_reset_canonicalise_email(self): """Test basic password reset flow Request password reset with different spelling @@ -153,8 +201,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.attempt_wrong_password_login("kermit", old_password) def test_cant_reset_password_without_clicking_link(self): - """Test that we do actually need to click the link in the email - """ + """Test that we do actually need to click the link in the email""" old_password = "monkey" new_password = "kangeroo" @@ -239,13 +286,20 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) - def _request_token(self, email, client_secret): + def _request_token(self, email, client_secret, ip="127.0.0.1"): channel = self.make_request( "POST", b"account/password/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, + client_ip=ip, ) - self.assertEquals(200, channel.code, channel.result) + + if channel.code != 200: + raise HttpResponseException( + channel.code, + channel.result["reason"], + channel.result["body"], + ) return channel.json_body["sid"] @@ -509,9 +563,22 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_address_trim(self): self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_ip(self): + """Tests that adding emails is ratelimited by IP""" + + # We expect to be able to set three emails before getting ratelimited. + self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) + self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) + self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + + with self.assertRaises(HttpResponseException) as cm: + self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + + self.assertEqual(cm.exception.code, 429) + def test_add_email_if_disabled(self): - """Test adding email to profile when doing so is disallowed - """ + """Test adding email to profile when doing so is disallowed""" self.hs.config.enable_3pid_changes = False client_secret = "foobar" @@ -541,15 +608,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email(self): - """Test deleting an email from profile - """ + """Test deleting an email from profile""" # Add a threepid self.get_success( self.store.user_add_threepid( @@ -571,15 +639,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email_if_disabled(self): - """Test deleting an email from profile when disallowed - """ + """Test deleting an email from profile when disallowed""" self.hs.config.enable_3pid_changes = False # Add a threepid @@ -605,7 +674,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -613,8 +684,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) def test_cant_add_email_without_clicking_link(self): - """Test that we do actually need to click the link in the email - """ + """Test that we do actually need to click the link in the email""" client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -640,7 +710,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -673,7 +745,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) @@ -718,7 +792,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Ensure not providing a next_link parameter still works self._request_token( - "something@example.com", "some_secret", next_link=None, expect_code=200, + "something@example.com", + "some_secret", + next_link=None, + expect_code=200, ) self._request_token( @@ -776,13 +853,27 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): if next_link: body["next_link"] = next_link - channel = self.make_request("POST", b"account/3pid/email/requestToken", body,) - self.assertEquals(expect_code, channel.code, channel.result) + channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + body, + ) + + if channel.code != expect_code: + raise HttpResponseException( + channel.code, + channel.result["reason"], + channel.result["body"], + ) return channel.json_body.get("sid") def _request_token_invalid_email( - self, email, expected_errcode, expected_error, client_secret="foobar", + self, + email, + expected_errcode, + expected_error, + client_secret="foobar", ): channel = self.make_request( "POST", @@ -821,12 +912,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): return match.group(0) def _add_email(self, request_email, expected_email): - """Test adding an email to profile - """ + """Test adding an email to profile""" + previous_email_attempts = len(self.email_attempts) + client_secret = "foobar" session_id = self._request_token(request_email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) link = self._get_link_from_email() self._validate_token(link) @@ -850,9 +942,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): # Get user channel = self.make_request( - "GET", self.url_3pid, access_token=self.user_id_tok, + "GET", + self.url_3pid, + access_token=self.user_id_tok, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"]) + + threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} + self.assertIn(expected_email, threepids) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index a6488a3d29..c26ad824f7 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -22,7 +22,7 @@ from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import auth, devices, register -from synapse.rest.oidc import OIDCResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import JsonDict, UserID from tests import unittest @@ -102,7 +102,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase): """Ensure that fallback auth via a captcha works.""" # Returns a 401 as per the spec channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"}, + 401, + {"username": "user", "type": "m.login.password", "password": "bar"}, ) # Grab the session @@ -173,9 +174,7 @@ class UIAuthTests(unittest.HomeserverTestCase): def create_resource_dict(self): resource_dict = super().create_resource_dict() - if HAS_OIDC: - # mount the OIDC resource at /_synapse/oidc - resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) + resource_dict.update(build_synapse_client_resource_tree(self.hs)) return resource_dict def prepare(self, reactor, clock, hs): @@ -193,7 +192,10 @@ class UIAuthTests(unittest.HomeserverTestCase): ) -> FakeChannel: """Delete an individual device.""" channel = self.make_request( - "DELETE", "devices/" + device, body, access_token=access_token, + "DELETE", + "devices/" + device, + body, + access_token=access_token, ) # Ensure the response is sane. @@ -206,7 +208,10 @@ class UIAuthTests(unittest.HomeserverTestCase): # Note that this uses the delete_devices endpoint so that we can modify # the payload half-way through some tests. channel = self.make_request( - "POST", "delete_devices", body, access_token=self.user_tok, + "POST", + "delete_devices", + body, + access_token=self.user_tok, ) # Ensure the response is sane. @@ -338,7 +343,7 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) - @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}}) + @unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) def test_can_reuse_session(self): """ The session can be reused if configured. @@ -419,7 +424,10 @@ class UIAuthTests(unittest.HomeserverTestCase): # and now the delete request should succeed. self.delete_device( - self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}}, + self.user_tok, + self.device_id, + 200, + body={"auth": {"session": session_id}}, ) @skip_unless(HAS_OIDC, "requires OIDC") @@ -445,8 +453,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_offers_both_flows_for_upgraded_user(self): - """A user that had a password and then logged in with SSO should get both flows - """ + """A user that had a password and then logged in with SSO should get both flows""" login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) @@ -461,8 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase): @skip_unless(HAS_OIDC, "requires OIDC") @override_config({"oidc_config": TEST_OIDC_CONFIG}) def test_ui_auth_fails_for_incorrect_sso_user(self): - """If the user tries to authenticate with the wrong SSO user, they get an error - """ + """If the user tries to authenticate with the wrong SSO user, they get an error""" # log the user in login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py index fba34def30..5ebc5707a5 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -91,7 +91,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_TOO_SHORT, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_TOO_SHORT, + channel.result, ) def test_password_no_digit(self): @@ -100,7 +102,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_DIGIT, + channel.result, ) def test_password_no_symbol(self): @@ -109,7 +113,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_SYMBOL, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_SYMBOL, + channel.result, ) def test_password_no_uppercase(self): @@ -118,7 +124,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_UPPERCASE, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_UPPERCASE, + channel.result, ) def test_password_no_lowercase(self): @@ -127,7 +135,9 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) self.assertEqual( - channel.json_body["errcode"], Codes.PASSWORD_NO_LOWERCASE, channel.result, + channel.json_body["errcode"], + Codes.PASSWORD_NO_LOWERCASE, + channel.result, ) def test_password_compliant(self): diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index bd574077e7..7c457754f1 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -83,14 +83,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) def test_deny_membership(self): - """Test that we deny relations on membership events - """ + """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) self.assertEquals(400, channel.code, channel.json_body) def test_deny_double_react(self): - """Test that we deny relations on membership events - """ + """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") self.assertEquals(200, channel.code, channel.json_body) @@ -98,8 +96,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(400, channel.code, channel.json_body) def test_basic_paginate_relations(self): - """Tests that calling pagination API correctly the latest relations. - """ + """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") self.assertEquals(200, channel.code, channel.json_body) @@ -174,8 +171,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(found_event_ids, expected_event_ids) def test_aggregation_pagination_groups(self): - """Test that we can paginate annotation groups correctly. - """ + """Test that we can paginate annotation groups correctly.""" # We need to create ten separate users to send each reaction. access_tokens = [self.user_token, self.user2_token] @@ -240,8 +236,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(sent_groups, found_groups) def test_aggregation_pagination_within_group(self): - """Test that we can paginate within an annotation group. - """ + """Test that we can paginate within an annotation group.""" # We need to create ten separate users to send each reaction. access_tokens = [self.user_token, self.user2_token] @@ -311,8 +306,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(found_event_ids, expected_event_ids) def test_aggregation(self): - """Test that annotations get correctly aggregated. - """ + """Test that annotations get correctly aggregated.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -344,8 +338,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) def test_aggregation_redactions(self): - """Test that annotations get correctly aggregated after a redaction. - """ + """Test that annotations get correctly aggregated after a redaction.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -379,8 +372,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) def test_aggregation_must_be_annotation(self): - """Test that aggregations must be annotations. - """ + """Test that aggregations must be annotations.""" channel = self.make_request( "GET", @@ -437,8 +429,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) def test_edit(self): - """Test that a simple edit works. - """ + """Test that a simple edit works.""" new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 512e36c236..2dbf42397a 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -388,13 +388,19 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Check that room name changes increase the unread counter. self.helper.send_state( - self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, + self.room_id, + "m.room.name", + {"name": "my super room"}, + tok=self.tok2, ) self._check_unread_count(1) # Check that room topic changes increase the unread counter. self.helper.send_state( - self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, + self.room_id, + "m.room.topic", + {"topic": "welcome!!!"}, + tok=self.tok2, ) self._check_unread_count(2) @@ -404,7 +410,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Check that custom events with a body increase the unread counter. self.helper.send_event( - self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, + self.room_id, + "org.matrix.custom_type", + {"body": "hello"}, + tok=self.tok2, ) self._check_unread_count(4) @@ -443,14 +452,18 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): """Syncs and compares the unread count with the expected value.""" channel = self.make_request( - "GET", self.url % self.next_batch, access_token=self.tok, + "GET", + self.url % self.next_batch, + access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) room_entry = channel.json_body["rooms"]["join"][self.room_id] self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, + room_entry["org.matrix.msc2654.unread_count"], + expected_count, + room_entry, ) # Store the next batch for the next request. diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py new file mode 100644 index 0000000000..d890d11863 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_upgrade_room.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet + +from tests import unittest +from tests.server import FakeChannel + + +class UpgradeRoomTest(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + room_upgrade_rest_servlet.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + self.creator = self.register_user("creator", "pass") + self.creator_token = self.login(self.creator, "pass") + + self.other = self.register_user("user", "pass") + self.other_token = self.login(self.other, "pass") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token) + self.helper.join(self.room_id, self.other, tok=self.other_token) + + def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel: + # We never want a cached response. + self.reactor.advance(5 * 60 + 1) + + return self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id, + # This will upgrade a room to the same version, but that's fine. + content={"new_version": DEFAULT_ROOM_VERSION}, + access_token=token or self.creator_token, + ) + + def test_upgrade(self): + """ + Upgrading a room should work fine. + """ + channel = self._upgrade_room() + self.assertEquals(200, channel.code, channel.result) + self.assertIn("replacement_room", channel.json_body) + + def test_not_in_room(self): + """ + Upgrading a room should work fine. + """ + # THe user isn't in the room. + roomless = self.register_user("roomless", "pass") + roomless_token = self.login(roomless, "pass") + + channel = self._upgrade_room(roomless_token) + self.assertEquals(403, channel.code, channel.result) + + def test_power_levels(self): + """ + Another user can upgrade the room if their power level is increased. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["users"][self.other] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_user_default(self): + """ + Another user can upgrade the room if the default power level for users is increased. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["users_default"] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_tombstone(self): + """ + Another user can upgrade the room if they can send the tombstone event. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["events"]["m.room.tombstone"] = 0 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + self.assertNotIn(self.other, power_levels["users"]) diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 5e90d656f7..9d0d0ef414 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -180,7 +180,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): async def post_json(destination, path, data): self.assertEqual(destination, self.hs.hostname) self.assertEqual( - path, "/_matrix/key/v2/query", + path, + "/_matrix/key/v2/query", ) channel = FakeChannel(self.site, self.reactor) @@ -188,7 +189,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): req.content = BytesIO(encode_canonical_json(data)) req.requestReceived( - b"POST", path.encode("utf-8"), b"1.1", + b"POST", + path.encode("utf-8"), + b"1.1", ) channel.await_result() self.assertEqual(channel.code, 200) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index ae2b32b131..0789b12392 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -30,6 +30,8 @@ from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.logging.context import make_deferred_yieldable +from synapse.rest import admin +from synapse.rest.client.v1 import login from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage @@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from tests import unittest from tests.server import FakeSite, make_request +from tests.utils import default_config class MediaStorageTests(unittest.HomeserverTestCase): @@ -164,7 +167,16 @@ class _TestImage: ), ), # an empty file - (_TestImage(b"", b"image/gif", b".gif", None, None, False,),), + ( + _TestImage( + b"", + b"image/gif", + b".gif", + None, + None, + False, + ), + ), ], ) class MediaRepoTests(unittest.HomeserverTestCase): @@ -202,7 +214,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): config = self.default_config() config["media_store_path"] = self.media_store_path - config["thumbnail_requirements"] = {} config["max_image_pixels"] = 2000000 provider_config = { @@ -313,15 +324,39 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) def test_thumbnail_crop(self): + """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( "crop", self.test_image.expected_cropped, self.test_image.expected_found ) def test_thumbnail_scale(self): + """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( "scale", self.test_image.expected_scaled, self.test_image.expected_found ) + def test_invalid_type(self): + """An invalid thumbnail type is never available.""" + self._test_thumbnail("invalid", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} + ) + def test_no_thumbnail_crop(self): + """ + Override the config to generate only scaled thumbnails, but request a cropped one. + """ + self._test_thumbnail("crop", None, False) + + @unittest.override_config( + {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} + ) + def test_no_thumbnail_scale(self): + """ + Override the config to generate only cropped thumbnails, but request a scaled one. + """ + self._test_thumbnail("scale", None, False) + def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method channel = make_request( @@ -375,3 +410,93 @@ class MediaRepoTests(unittest.HomeserverTestCase): headers.getRawHeaders(b"X-Robots-Tag"), [b"noindex, nofollow, noarchive, noimageindex"], ) + + +class TestSpamChecker: + """A spam checker module that rejects all media that includes the bytes + `evil`. + """ + + def __init__(self, config, api): + self.config = config + self.api = api + + def parse_config(config): + return config + + async def check_event_for_spam(self, foo): + return False # allow all events + + async def user_may_invite(self, inviter_userid, invitee_userid, room_id): + return True # allow all invites + + async def user_may_create_room(self, userid): + return True # allow all room creations + + async def user_may_create_room_alias(self, userid, room_alias): + return True # allow all room aliases + + async def user_may_publish_room(self, userid, room_id): + return True # allow publishing of all rooms + + async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool: + buf = BytesIO() + await file_wrapper.write_chunks_to(buf.write) + + return b"evil" in buf.getvalue() + + +class SpamCheckerTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + admin.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + # Allow for uploading and downloading to/from the media repo + self.media_repo = hs.get_media_repository_resource() + self.download_resource = self.media_repo.children[b"download"] + self.upload_resource = self.media_repo.children[b"upload"] + + def default_config(self): + config = default_config("test") + + config.update( + { + "spam_checker": [ + { + "module": TestSpamChecker.__module__ + ".TestSpamChecker", + "config": {}, + } + ] + } + ) + + return config + + def test_upload_innocent(self): + """Attempt to upload some innocent data that should be allowed.""" + + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + self.helper.upload_media( + self.upload_resource, image_data, tok=self.tok, expect_code=200 + ) + + def test_upload_ban(self): + """Attempt to upload some data that includes bytes "evil", which should + get rejected by the spam checker. + """ + + data = b"Some evil data" + + self.helper.upload_media( + self.upload_resource, data, tok=self.tok, expect_code=400 + ) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index c5e44af9f7..14de0921be 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -40,3 +40,12 @@ class WellKnownTests(unittest.HomeserverTestCase): "m.identity_server": {"base_url": "https://testis"}, }, ) + + def test_well_known_no_public_baseurl(self): + self.hs.config.public_baseurl = None + + channel = self.make_request( + "GET", "/.well-known/matrix/client", shorthand=False + ) + + self.assertEqual(channel.code, 404) |