diff options
Diffstat (limited to 'tests/rest')
-rw-r--r-- | tests/rest/admin/test_admin.py | 48 | ||||
-rw-r--r-- | tests/rest/admin/test_background_updates.py | 154 | ||||
-rw-r--r-- | tests/rest/admin/test_room.py | 982 | ||||
-rw-r--r-- | tests/rest/admin/test_user.py | 30 | ||||
-rw-r--r-- | tests/rest/client/test_capabilities.py | 8 | ||||
-rw-r--r-- | tests/rest/client/test_directory.py | 105 | ||||
-rw-r--r-- | tests/rest/client/test_login.py | 73 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 173 | ||||
-rw-r--r-- | tests/rest/client/test_rooms.py | 154 | ||||
-rw-r--r-- | tests/rest/client/utils.py | 71 |
10 files changed, 1635 insertions, 163 deletions
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 192073c520..af849bd471 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -474,3 +474,51 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): % server_and_media_id_2 ), ) + + +class PurgeHistoryTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = f"/_synapse/admin/v1/purge_history/{self.room_id}" + self.url_status = "/_synapse/admin/v1/purge_history_status/" + + def test_purge_history(self): + """ + Simple test of purge history API. + Test only that is is possible to call, get status 200 and purge_id. + """ + + channel = self.make_request( + "POST", + self.url, + content={"delete_local_events": True, "purge_up_to_ts": 0}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("purge_id", channel.json_body) + purge_id = channel.json_body["purge_id"] + + # get status + channel = self.make_request( + "GET", + self.url_status + purge_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("complete", channel.json_body["status"]) diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index 78c48db552..1786316763 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -11,8 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus +from typing import Collection + +from parameterized import parameterized import synapse.rest.admin +from synapse.api.errors import Codes from synapse.rest.client import login from synapse.server import HomeServer @@ -30,6 +35,60 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") + @parameterized.expand( + [ + ("GET", "/_synapse/admin/v1/background_updates/enabled"), + ("POST", "/_synapse/admin/v1/background_updates/enabled"), + ("GET", "/_synapse/admin/v1/background_updates/status"), + ("POST", "/_synapse/admin/v1/background_updates/start_job"), + ] + ) + def test_requester_is_no_admin(self, method: str, url: str): + """ + If the user is not a server admin, an error 403 is returned. + """ + + self.register_user("user", "pass", admin=False) + other_user_tok = self.login("user", "pass") + + channel = self.make_request( + method, + url, + content={}, + access_token=other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + url = "/_synapse/admin/v1/background_updates/start_job" + + # empty content + channel = self.make_request( + "POST", + url, + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # job_name invalid + channel = self.make_request( + "POST", + url, + content={"job_name": "unknown"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + def _register_bg_update(self): "Adds a bg update but doesn't start it" @@ -60,7 +119,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Background updates should be enabled, but none should be running. self.assertDictEqual( @@ -82,7 +141,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Background updates should be enabled, and one should be running. self.assertDictEqual( @@ -114,7 +173,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/enabled", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": True}) # Disable the BG updates @@ -124,7 +183,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): content={"enabled": False}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": False}) # Advance a bit and get the current status, note this will finish the in @@ -137,7 +196,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual( channel.json_body, { @@ -162,7 +221,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # There should be no change from the previous /status response. self.assertDictEqual( @@ -188,7 +247,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): content={"enabled": True}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertDictEqual(channel.json_body, {"enabled": True}) @@ -199,7 +258,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "/_synapse/admin/v1/background_updates/status", access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) # Background updates should be enabled and making progress. self.assertDictEqual( @@ -216,3 +275,82 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase): "enabled": True, }, ) + + @parameterized.expand( + [ + ("populate_stats_process_rooms", ["populate_stats_process_rooms"]), + ( + "regenerate_directory", + [ + "populate_user_directory_createtables", + "populate_user_directory_process_rooms", + "populate_user_directory_process_users", + "populate_user_directory_cleanup", + ], + ), + ] + ) + def test_start_backround_job(self, job_name: str, updates: Collection[str]): + """ + Test that background updates add to database and be processed. + + Args: + job_name: name of the job to call with API + updates: collection of background updates to be started + """ + + # no background update is waiting + self.assertTrue( + self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ) + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/background_updates/start_job", + content={"job_name": job_name}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # test that each background update is waiting now + for update in updates: + self.assertFalse( + self.get_success( + self.store.db_pool.updates.has_completed_background_update(update) + ) + ) + + self.wait_for_background_updates() + + # background updates are done + self.assertTrue( + self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ) + ) + + def test_start_backround_job_twice(self): + """Test that add a background update twice return an error.""" + + # add job to database + self.get_success( + self.store.db_pool.simple_insert( + table="background_updates", + values={ + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + }, + ) + ) + + channel = self.make_request( + "POST", + "/_synapse/admin/v1/background_updates/start_job", + content={"job_name": "populate_stats_process_rooms"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 46116644ce..07077aff78 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -14,12 +14,16 @@ import json import urllib.parse +from http import HTTPStatus from typing import List, Optional from unittest.mock import Mock +from parameterized import parameterized + import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes +from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room from tests import unittest @@ -68,11 +72,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - json.dumps({}), + {}, access_token=self.other_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_room_does_not_exist(self): @@ -84,11 +88,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", url, - json.dumps({}), + {}, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_room_is_not_valid(self): @@ -100,11 +104,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", url, - json.dumps({}), + {}, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom is not a legal room ID", channel.json_body["error"], @@ -119,11 +123,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertIn("new_room_id", channel.json_body) self.assertIn("kicked_users", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -138,11 +142,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "User must be our own: @not:exist.bla", channel.json_body["error"], @@ -157,11 +161,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_is_not_bool(self): @@ -173,11 +177,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) def test_purge_room_and_block(self): @@ -199,11 +203,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -232,11 +236,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -266,11 +270,11 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "DELETE", self.url.encode("ascii"), - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(None, channel.json_body["new_room_id"]) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("failed_to_kick_users", channel.json_body) @@ -281,6 +285,31 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._is_blocked(self.room_id, expect=True) self._has_no_members(self.room_id) + @parameterized.expand([(True,), (False,)]) + def test_block_unknown_room(self, purge: bool) -> None: + """ + We can block an unknown room. In this case, the `purge` argument + should be ignored. + """ + room_id = "!unknown:test" + + # The room isn't already in the blocked rooms table + self._is_blocked(room_id, expect=False) + + # Request the room be blocked. + channel = self.make_request( + "DELETE", + f"/_synapse/admin/v1/rooms/{room_id}", + {"block": True, "purge": purge}, + access_token=self.admin_user_tok, + ) + + # The room is now blocked. + self.assertEqual( + HTTPStatus.OK, int(channel.result["code"]), msg=channel.result["body"] + ) + self._is_blocked(room_id) + def test_shutdown_room_consent(self): """Test that we can shutdown rooms with local users who have not yet accepted the privacy policy. This used to fail when we tried to @@ -316,7 +345,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -345,7 +374,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): json.dumps({"history_visibility": "world_readable"}), access_token=self.other_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Test that room is not purged with self.assertRaises(AssertionError): @@ -362,7 +391,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) self.assertIn("new_room_id", channel.json_body) self.assertIn("failed_to_kick_users", channel.json_body) @@ -418,17 +447,616 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + url = "events?timeout=0&room_id=" + room_id + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + +class DeleteRoomV2TestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + events.register_servlets, + room.register_servlets, + room.register_deprecated_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.event_creation_handler = hs.get_event_creation_handler() + hs.config.consent.user_consent_version = "1" + + consent_uri_builder = Mock() + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" + self.event_creation_handler._consent_uri_builder = consent_uri_builder + + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # Mark the admin user as having consented + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = f"/_synapse/admin/v2/rooms/{self.room_id}" + self.url_status_by_room_id = ( + f"/_synapse/admin/v2/rooms/{self.room_id}/delete_status" + ) + self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/" + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), + ] + ) + def test_requester_is_no_admin(self, method: str, url: str): + """ + If the user is not a server admin, an error 403 is returned. + """ + + channel = self.make_request( + method, + url % self.room_id, + content={}, + access_token=self.other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"), + ] + ) + def test_room_does_not_exist(self, method: str, url: str): + """ + Check that unknown rooms/server return error 404. + """ + + channel = self.make_request( + method, + url % "!unknown:test", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + @parameterized.expand( + [ + ("DELETE", "/_synapse/admin/v2/rooms/%s"), + ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"), + ] + ) + def test_room_is_not_valid(self, method: str, url: str): + """ + Check that invalid room names, return an error 400. + """ + + channel = self.make_request( + method, + url % "invalidroom", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual( + "invalidroom is not a legal room ID", + channel.json_body["error"], + ) + + def test_new_room_user_does_not_exist(self): + """ + Tests that the user ID must be from local server but it does not have to exist. + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": "@unknown:test"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + def test_new_room_user_is_not_local(self): + """ + Check that only local users can create new room to move members. + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": "@not:exist.bla"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + "User must be our own: @not:exist.bla", + channel.json_body["error"], ) + def test_block_is_not_bool(self): + """ + If parameter `block` is not boolean, return an error + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"block": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_purge_is_not_bool(self): + """ + If parameter `purge` is not boolean, return an error + """ + + channel = self.make_request( + "DELETE", + self.url, + content={"purge": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + def test_delete_expired_status(self): + """Test that the task status is removed after expiration.""" + + # first task, do not purge, that we can create a second task + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"purge": False}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id1 = channel.json_body["delete_id"] + + # go ahead + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + # second task + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id2 = channel.json_body["delete_id"] + + # get status + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(2, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual("complete", channel.json_body["results"][1]["status"]) + self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"]) + self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"]) + + # get status after more than clearing time for first task + # second task is not cleared + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + self.assertEqual("complete", channel.json_body["results"][0]["status"]) + self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"]) + + # get status after more than clearing time for all tasks + self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_delete_same_room_twice(self): + """Test that the call for delete a room at second time gives an exception.""" + + body = {"new_room_user_id": self.admin_user} + + # first call to delete room + # and do not wait for finish the task + first_channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content=body, + access_token=self.admin_user_tok, + await_result=False, + ) + + # second call to delete room + second_channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content=body, + access_token=self.admin_user_tok, + ) + + self.assertEqual( + HTTPStatus.BAD_REQUEST, second_channel.code, msg=second_channel.json_body + ) + self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"]) + self.assertEqual( + f"History purge already in progress for {self.room_id}", + second_channel.json_body["error"], + ) + + # get result of first call + first_channel.await_result() + self.assertEqual(HTTPStatus.OK, first_channel.code, msg=first_channel.json_body) + self.assertIn("delete_id", first_channel.json_body) + + # check status after finish the task + self._test_result( + first_channel.json_body["delete_id"], + self.other_user, + expect_new_room=True, + ) + + def test_purge_room_and_block(self): + """Test to purge a room and block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": True, "purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_purge_room_and_not_block(self): + """Test to purge a room and do not block it. + Members will not be moved to a new room and will not receive a message. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": False, "purge": True}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + + def test_block_room_and_not_purge(self): + """Test to block a room without purging it. + Members will not be moved to a new room and will not receive a message. + The room will not be purged. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + channel = self.make_request( + "DELETE", + self.url.encode("ascii"), + content={"block": True, "purge": False}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user) + + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=True) + self._has_no_members(self.room_id) + + def test_shutdown_room_consent(self): + """Test that we can shutdown rooms with local users who have not + yet accepted the privacy policy. This used to fail when we tried to + force part the user from the old room. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Assert one user in room + users_in_room = self.get_success(self.store.get_users_in_room(self.room_id)) + self.assertEqual([self.other_user], users_in_room) + + # Enable require consent to send events + self.event_creation_handler._block_events_without_consent_error = "Error" + + # Assert that the user is getting consent error + self.helper.send( + self.room_id, body="foo", tok=self.other_user_tok, expect_code=403 + ) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": self.admin_user}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"], + user_id=self.other_user, + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + def test_shutdown_room_block_peek(self): + """Test that a world_readable room can no longer be peeked into after + it has been shut down. + Members will be moved to a new room and will receive a message. + """ + self.event_creation_handler._block_events_without_consent_error = None + + # Enable world readable + url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,) + channel = self.make_request( + "PUT", + url.encode("ascii"), + content={"history_visibility": "world_readable"}, + access_token=self.other_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + # Test that the admin can still send shutdown + channel = self.make_request( + "DELETE", + self.url, + content={"new_room_user_id": self.admin_user}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertIn("delete_id", channel.json_body) + delete_id = channel.json_body["delete_id"] + + self._test_result(delete_id, self.other_user, expect_new_room=True) + + channel = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Test that member has moved to new room + self._is_member( + room_id=channel.json_body["results"][0]["shutdown_room"]["new_room_id"], + user_id=self.other_user, + ) + + self._is_purged(self.room_id) + self._has_no_members(self.room_id) + + # Assert we can no longer peek into the room + self._assert_peek(self.room_id, expect_code=403) + + def _is_blocked(self, room_id: str, expect: bool = True) -> None: + """Assert that the room is blocked or not""" + d = self.store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertIsNone(self.get_success(d)) + + def _has_no_members(self, room_id: str) -> None: + """Assert there is now no longer anyone in the room""" + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual([], users_in_room) + + def _is_member(self, room_id: str, user_id: str) -> None: + """Test that user is member of the room""" + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) + self.assertIn(user_id, users_in_room) + + def _is_purged(self, room_id: str) -> None: + """Test that the following tables have been purged of all rows related to the room.""" + for table in PURGE_TABLES: + count = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table=table, + keyvalues={"room_id": room_id}, + retcol="COUNT(*)", + desc="test_purge_room", + ) + ) + + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") + + def _assert_peek(self, room_id: str, expect_code: int) -> None: + """Assert that the admin user can (or cannot) peek into the room.""" + + url = f"rooms/{room_id}/initialSync" + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok + ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + url = "events?timeout=0&room_id=" + room_id channel = self.make_request( "GET", url.encode("ascii"), access_token=self.admin_user_tok ) + self.assertEqual(expect_code, channel.code, msg=channel.json_body) + + def _test_result( + self, + delete_id: str, + kicked_user: str, + expect_new_room: bool = False, + ) -> None: + """ + Test that the result is the expected. + Uses both APIs (status by room_id and delete_id) + + Args: + delete_id: id of this purge + kicked_user: a user_id which is kicked from the room + expect_new_room: if we expect that a new room was created + """ + + # get information by room_id + channel_room_id = self.make_request( + "GET", + self.url_status_by_room_id, + access_token=self.admin_user_tok, + ) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] + HTTPStatus.OK, channel_room_id.code, msg=channel_room_id.json_body ) + self.assertEqual(1, len(channel_room_id.json_body["results"])) + self.assertEqual( + delete_id, channel_room_id.json_body["results"][0]["delete_id"] + ) + + # get information by delete_id + channel_delete_id = self.make_request( + "GET", + self.url_status_by_delete_id + delete_id, + access_token=self.admin_user_tok, + ) + self.assertEqual( + HTTPStatus.OK, + channel_delete_id.code, + msg=channel_delete_id.json_body, + ) + + # test values that are the same in both responses + for content in [ + channel_room_id.json_body["results"][0], + channel_delete_id.json_body, + ]: + self.assertEqual("complete", content["status"]) + self.assertEqual(kicked_user, content["shutdown_room"]["kicked_users"][0]) + self.assertIn("failed_to_kick_users", content["shutdown_room"]) + self.assertIn("local_aliases", content["shutdown_room"]) + self.assertNotIn("error", content) + + if expect_new_room: + self.assertIsNotNone(content["shutdown_room"]["new_room_id"]) + else: + self.assertIsNone(content["shutdown_room"]["new_room_id"]) class RoomTestCase(unittest.HomeserverTestCase): @@ -466,7 +1094,7 @@ class RoomTestCase(unittest.HomeserverTestCase): ) # Check request completed successfully - self.assertEqual(200, int(channel.code), msg=channel.json_body) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that response json body contains a "rooms" key self.assertTrue( @@ -550,9 +1178,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertTrue("rooms" in channel.json_body) for r in channel.json_body["rooms"]: @@ -592,7 +1218,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_correct_room_attributes(self): """Test the correct attributes for a room are returned""" @@ -615,7 +1241,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -647,7 +1273,7 @@ class RoomTestCase(unittest.HomeserverTestCase): url.encode("ascii"), access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Check that rooms were returned self.assertTrue("rooms" in channel.json_body) @@ -1107,7 +1733,7 @@ class RoomTestCase(unittest.HomeserverTestCase): {"room_id": room_id}, access_token=admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Set this new alias as the canonical alias for this room self.helper.send_state( @@ -1157,11 +1783,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.second_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_invalid_parameter(self): @@ -1173,11 +1799,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) def test_local_user_does_not_exist(self): @@ -1189,11 +1815,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) def test_remote_user(self): @@ -1205,11 +1831,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "This endpoint can only be used with local users", channel.json_body["error"], @@ -1225,11 +1851,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("No known servers", channel.json_body["error"]) def test_room_is_not_valid(self): @@ -1242,11 +1868,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "invalidroom was not legal room ID or room alias", channel.json_body["error"], @@ -1261,11 +1887,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1275,7 +1901,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_not_member(self): @@ -1292,11 +1918,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_join_private_room_if_member(self): @@ -1324,7 +1950,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) # Join user to room. @@ -1335,10 +1961,10 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1348,7 +1974,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_join_private_room_if_owner(self): @@ -1365,11 +1991,11 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, - content=body.encode(encoding="utf_8"), + content=body, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["room_id"]) # Validate if user is a member of the room @@ -1379,7 +2005,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/joined_rooms", access_token=self.second_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) def test_context_as_non_admin(self): @@ -1413,9 +2039,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=tok, ) - self.assertEquals( - 403, int(channel.result["code"]), msg=channel.result["body"] - ) + self.assertEquals(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_context_as_admin(self): @@ -1445,7 +2069,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): % (room_id, events[midway]["event_id"]), access_token=self.admin_user_tok, ) - self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEquals(200, channel.code, msg=channel.json_body) self.assertEquals( channel.json_body["event"]["event_id"], events[midway]["event_id"] ) @@ -1504,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) @@ -1531,7 +2155,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room (we should have received an # invite) and can ban a user. @@ -1557,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Now we test that we can join the room and ban a user. self.helper.join(room_id, self.second_user_id, tok=self.second_tok) @@ -1595,13 +2219,241 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): # # (Note we assert the error message to ensure that it's not denied for # some other reason) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( channel.json_body["error"], "No local admin user in room with power to update power levels.", ) +class BlockRoomTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self._store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok + ) + self.url = "/_synapse/admin/v1/rooms/%s/block" + + @parameterized.expand([("PUT",), ("GET",)]) + def test_requester_is_no_admin(self, method: str): + """If the user is not a server admin, an error 403 is returned.""" + + channel = self.make_request( + method, + self.url % self.room_id, + content={}, + access_token=self.other_user_tok, + ) + + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + @parameterized.expand([("PUT",), ("GET",)]) + def test_room_is_not_valid(self, method: str): + """Check that invalid room names, return an error 400.""" + + channel = self.make_request( + method, + self.url % "invalidroom", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual( + "invalidroom is not a legal room ID", + channel.json_body["error"], + ) + + def test_block_is_not_valid(self): + """If parameter `block` is not valid, return an error.""" + + # `block` is not valid + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={"block": "NotBool"}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + + # `block` is not set + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # no content is send + channel = self.make_request( + "PUT", + self.url % self.room_id, + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) + + def test_block_room(self): + """Test that block a room is successful.""" + + def _request_and_test_block_room(room_id: str) -> None: + self._is_blocked(room_id, expect=False) + channel = self.make_request( + "PUT", + self.url % room_id, + content={"block": True}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertTrue(channel.json_body["block"]) + self._is_blocked(room_id, expect=True) + + # known internal room + _request_and_test_block_room(self.room_id) + + # unknown internal room + _request_and_test_block_room("!unknown:test") + + # unknown remote room + _request_and_test_block_room("!unknown:remote") + + def test_block_room_twice(self): + """Test that block a room that is already blocked is successful.""" + + self._is_blocked(self.room_id, expect=False) + for _ in range(2): + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={"block": True}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertTrue(channel.json_body["block"]) + self._is_blocked(self.room_id, expect=True) + + def test_unblock_room(self): + """Test that unblock a room is successful.""" + + def _request_and_test_unblock_room(room_id: str) -> None: + self._block_room(room_id) + + channel = self.make_request( + "PUT", + self.url % room_id, + content={"block": False}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body["block"]) + self._is_blocked(room_id, expect=False) + + # known internal room + _request_and_test_unblock_room(self.room_id) + + # unknown internal room + _request_and_test_unblock_room("!unknown:test") + + # unknown remote room + _request_and_test_unblock_room("!unknown:remote") + + def test_unblock_room_twice(self): + """Test that unblock a room that is not blocked is successful.""" + + self._block_room(self.room_id) + for _ in range(2): + channel = self.make_request( + "PUT", + self.url % self.room_id, + content={"block": False}, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body["block"]) + self._is_blocked(self.room_id, expect=False) + + def test_get_blocked_room(self): + """Test get status of a blocked room""" + + def _request_blocked_room(room_id: str) -> None: + self._block_room(room_id) + + channel = self.make_request( + "GET", + self.url % room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertTrue(channel.json_body["block"]) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + + # known internal room + _request_blocked_room(self.room_id) + + # unknown internal room + _request_blocked_room("!unknown:test") + + # unknown remote room + _request_blocked_room("!unknown:remote") + + def test_get_unblocked_room(self): + """Test get status of a unblocked room""" + + def _request_unblocked_room(room_id: str) -> None: + self._is_blocked(room_id, expect=False) + + channel = self.make_request( + "GET", + self.url % room_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body["block"]) + self.assertNotIn("user_id", channel.json_body) + + # known internal room + _request_unblocked_room(self.room_id) + + # unknown internal room + _request_unblocked_room("!unknown:test") + + # unknown remote room + _request_unblocked_room("!unknown:remote") + + def _is_blocked(self, room_id: str, expect: bool = True) -> None: + """Assert that the room is blocked or not""" + d = self._store.is_room_blocked(room_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertIsNone(self.get_success(d)) + + def _block_room(self, room_id: str) -> None: + """Block a room in database""" + self.get_success(self._store.block_room(room_id, self.other_user)) + self._is_blocked(room_id, expect=True) + + PURGE_TABLES = [ "current_state_events", "event_backward_extremities", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 25e8d6cf27..5011e54563 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1169,14 +1169,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): # regardless of whether password login or SSO is allowed self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.admin_user, device_id=None, valid_until_ms=None ) ) self.other_user = self.register_user("user", "pass", displayname="User") self.other_user_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.other_user, device_id=None, valid_until_ms=None ) ) @@ -3592,31 +3592,34 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): + @parameterized.expand(["POST", "DELETE"]) + def test_no_auth(self, method: str): """ Try to get information of an user without authentication. """ - channel = self.make_request("POST", self.url) + channel = self.make_request(method, self.url) self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_not_admin(self): + @parameterized.expand(["POST", "DELETE"]) + def test_requester_is_not_admin(self, method: str): """ If the user is not a server admin, an error is returned. """ other_user_token = self.login("user", "pass") - channel = self.make_request("POST", self.url, access_token=other_user_token) + channel = self.make_request(method, self.url, access_token=other_user_token) self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_is_not_local(self): + @parameterized.expand(["POST", "DELETE"]) + def test_user_is_not_local(self, method: str): """ Tests that shadow-banning for a user that is not a local returns a 400 """ url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" - channel = self.make_request("POST", url, access_token=self.admin_user_tok) + channel = self.make_request(method, url, access_token=self.admin_user_tok) self.assertEqual(400, channel.code, msg=channel.json_body) def test_success(self): @@ -3636,6 +3639,17 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): result = self.get_success(self.store.get_user_by_access_token(other_user_token)) self.assertTrue(result.shadow_banned) + # Un-shadow-ban the user. + channel = self.make_request( + "DELETE", self.url, access_token=self.admin_user_tok + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + # Ensure the user is no longer shadow-banned (and the cache was cleared). + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertFalse(result.shadow_banned) + class RateLimitTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index b9e3602552..249808b031 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -71,7 +71,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"localdb_enabled": False}}) def test_get_change_password_capabilities_localdb_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -85,7 +85,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"enabled": False}}) def test_get_change_password_capabilities_password_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -174,7 +174,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"experimental_features": {"msc3244_enabled": False}}) def test_get_does_not_include_msc3244_fields_when_disabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) @@ -189,7 +189,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): def test_get_does_include_msc3244_fields_when_enabled(self): access_token = self.get_success( - self.auth_handler.get_access_token_for_user_id( + self.auth_handler.create_access_token_for_user_id( self.user, device_id=None, valid_until_ms=None ) ) diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py index d2181ea907..aca03afd0e 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -11,12 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin from synapse.rest.client import directory, login, room +from synapse.server import HomeServer from synapse.types import RoomAlias +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -32,7 +36,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["require_membership_for_aliases"] = True @@ -40,7 +44,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase): return self.hs - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + """Create two local users and access tokens for them. + One of them creates a room.""" self.room_owner = self.register_user("room_owner", "test") self.room_owner_tok = self.login("room_owner", "test") @@ -51,39 +59,39 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.user = self.register_user("user", "test") self.user_tok = self.login("user", "test") - def test_state_event_not_in_room(self): + def test_state_event_not_in_room(self) -> None: self.ensure_user_left_room() - self.set_alias_via_state_event(403) + self.set_alias_via_state_event(HTTPStatus.FORBIDDEN) - def test_directory_endpoint_not_in_room(self): + def test_directory_endpoint_not_in_room(self) -> None: self.ensure_user_left_room() - self.set_alias_via_directory(403) + self.set_alias_via_directory(HTTPStatus.FORBIDDEN) - def test_state_event_in_room_too_long(self): + def test_state_event_in_room_too_long(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_state_event(400, alias_length=256) + self.set_alias_via_state_event(HTTPStatus.BAD_REQUEST, alias_length=256) - def test_directory_in_room_too_long(self): + def test_directory_in_room_too_long(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_directory(400, alias_length=256) + self.set_alias_via_directory(HTTPStatus.BAD_REQUEST, alias_length=256) @override_config({"default_room_version": 5}) - def test_state_event_user_in_v5_room(self): + def test_state_event_user_in_v5_room(self) -> None: """Test that a regular user can add alias events before room v6""" self.ensure_user_joined_room() - self.set_alias_via_state_event(200) + self.set_alias_via_state_event(HTTPStatus.OK) @override_config({"default_room_version": 6}) - def test_state_event_v6_room(self): + def test_state_event_v6_room(self) -> None: """Test that a regular user can *not* add alias events from room v6""" self.ensure_user_joined_room() - self.set_alias_via_state_event(403) + self.set_alias_via_state_event(HTTPStatus.FORBIDDEN) - def test_directory_in_room(self): + def test_directory_in_room(self) -> None: self.ensure_user_joined_room() - self.set_alias_via_directory(200) + self.set_alias_via_directory(HTTPStatus.OK) - def test_room_creation_too_long(self): + def test_room_creation_too_long(self) -> None: url = "/_matrix/client/r0/createRoom" # We use deliberately a localpart under the length threshold so @@ -93,9 +101,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - def test_room_creation(self): + def test_room_creation(self) -> None: url = "/_matrix/client/r0/createRoom" # Check with an alias of allowed length. There should already be @@ -106,9 +114,46 @@ class DirectoryTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + def test_deleting_alias_via_directory(self) -> None: + # Add an alias for the room. We must be joined to do so. + self.ensure_user_joined_room() + alias = self.set_alias_via_directory(HTTPStatus.OK) + + # Then try to remove the alias + channel = self.make_request( + "DELETE", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + def test_deleting_nonexistant_alias(self) -> None: + # Check that no alias exists + alias = "#potato:test" + channel = self.make_request( + "GET", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) + self.assertIn("error", channel.json_body, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body) + + # Then try to remove the alias + channel = self.make_request( + "DELETE", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) + self.assertIn("error", channel.json_body, channel.json_body) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND", channel.json_body) - def set_alias_via_state_event(self, expected_code, alias_length=5): + def set_alias_via_state_event( + self, expected_code: HTTPStatus, alias_length: int = 5 + ) -> None: url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( self.room_id, self.hs.hostname, @@ -122,8 +167,11 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, expected_code, channel.result) - def set_alias_via_directory(self, expected_code, alias_length=5): - url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) + def set_alias_via_directory( + self, expected_code: HTTPStatus, alias_length: int = 5 + ) -> str: + alias = self.random_alias(alias_length) + url = "/_matrix/client/r0/directory/room/%s" % alias data = {"room_id": self.room_id} request_data = json.dumps(data) @@ -131,17 +179,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase): "PUT", url, request_data, access_token=self.user_tok ) self.assertEqual(channel.code, expected_code, channel.result) + return alias - def random_alias(self, length): + def random_alias(self, length: int) -> str: return RoomAlias(random_string(length), self.hs.hostname).to_string() - def ensure_user_left_room(self): + def ensure_user_left_room(self) -> None: self.ensure_membership("leave") - def ensure_user_joined_room(self): + def ensure_user_joined_room(self) -> None: self.ensure_membership("join") - def ensure_membership(self, membership): + def ensure_membership(self, membership: str) -> None: try: if membership == "leave": self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index a63f04bd41..19f5e46537 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -79,7 +79,10 @@ EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fΓΆ&=o"')] # (possibly experimental) login flows we expect to appear in the list after the normal # ones -ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] +ADDITIONAL_LOGIN_FLOWS = [ + {"type": "m.login.application_service"}, + {"type": "uk.half-shot.msc2778.login.application_service"}, +] class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -812,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase): jwt_secret = "secret" jwt_algorithm = "HS256" + base_config = { + "enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + } - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt.jwt_enabled = True - self.hs.config.jwt.jwt_secret = self.jwt_secret - self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm - return self.hs + def default_config(self): + config = super().default_config() + + # If jwt_config has been defined (eg via @override_config), don't replace it. + if config.get("jwt_config") is None: + config["jwt_config"] = self.base_config + + return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. @@ -876,16 +886,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "issuer": "test-issuer", - } - } - ) + @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}}) def test_login_iss(self): """Test validating the issuer claim.""" # A valid issuer. @@ -916,16 +917,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"200", channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "audiences": ["test-audience"], - } - } - ) + @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) def test_login_aud(self): """Test validating the audience claim.""" # A valid audience. @@ -959,6 +951,19 @@ class JWTTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Invalid audience" ) + def test_login_default_sub(self): + """Test reading user ID from the default subject claim.""" + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) + def test_login_custom_sub(self): + """Test reading user ID from a custom subject claim.""" + channel = self.jwt_login({"username": "frog"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@frog:test") + def test_login_no_token(self): params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) @@ -1021,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): ] ) - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt.jwt_enabled = True - self.hs.config.jwt.jwt_secret = self.jwt_pubkey - self.hs.config.jwt.jwt_algorithm = "RS256" - return self.hs + def default_config(self): + config = super().default_config() + config["jwt_config"] = { + "enabled": True, + "secret": self.jwt_pubkey, + "algorithm": "RS256", + } + return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 78c2fb86b9..eb10d43217 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1,4 +1,5 @@ # Copyright 2019 New Vector Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -46,6 +47,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): return config def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.user_id, self.user_token = self._create_user("alice") self.user2_id, self.user2_token = self._create_user("bob") @@ -91,6 +94,49 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) self.assertEquals(400, channel.code, channel.json_body) + def test_deny_invalid_event(self): + """Test that we deny relations on non-existant events""" + channel = self._send_relation( + RelationTypes.ANNOTATION, + EventTypes.Message, + parent_id="foo", + content={"body": "foo", "msgtype": "m.text"}, + ) + self.assertEquals(400, channel.code, channel.json_body) + + # Unless that event is referenced from another event! + self.get_success( + self.hs.get_datastore().db_pool.simple_insert( + table="event_relations", + values={ + "event_id": "bar", + "relates_to_id": "foo", + "relation_type": RelationTypes.THREAD, + }, + desc="test_deny_invalid_event", + ) + ) + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + parent_id="foo", + content={"body": "foo", "msgtype": "m.text"}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + def test_deny_invalid_room(self): + """Test that we deny relations on non-existant events""" + # Create another room and send a message in it. + room2 = self.helper.create_room_as(self.user_id, tok=self.user_token) + res = self.helper.send(room2, body="Hi!", tok=self.user_token) + parent_id = res["event_id"] + + # Attempt to send an annotation to that event. + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" + ) + self.assertEquals(400, channel.code, channel.json_body) + def test_deny_double_react(self): """Test that we deny relations on membership events""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") @@ -99,6 +145,25 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(400, channel.code, channel.json_body) + def test_deny_forked_thread(self): + """It is invalid to start a thread off a thread.""" + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "foo"}, + parent_id=self.parent_id, + ) + self.assertEquals(200, channel.code, channel.json_body) + parent_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "foo"}, + parent_id=parent_id, + ) + self.assertEquals(400, channel.code, channel.json_body) + def test_basic_paginate_relations(self): """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -703,6 +768,52 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertIn("chunk", channel.json_body) self.assertEquals(channel.json_body["chunk"], []) + def test_unknown_relations(self): + """Unknown relations should be accepted.""" + channel = self._send_relation("m.relation.test", "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the full + # relation event we sent above. + self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"}, + channel.json_body["chunk"][0], + ) + + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEquals( + channel.json_body["original_event"]["event_id"], self.parent_id + ) + + # When bundling the unknown relation is not included. + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + + # But unknown relations can be directly queried. + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEquals(channel.json_body["chunk"], []) + def _send_relation( self, relation_type: str, @@ -749,3 +860,65 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token = self.login(localpart, "abc123") return user_id, access_token + + def test_background_update(self): + """Test the event_arbitrary_relations background update.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="π") + self.assertEquals(200, channel.code, channel.json_body) + annotation_event_id_good = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") + self.assertEquals(200, channel.code, channel.json_body) + annotation_event_id_bad = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + thread_event_id = channel.json_body["event_id"] + + # Clean-up the table as if the inserts did not happen during event creation. + self.get_success( + self.store.db_pool.simple_delete_many( + table="event_relations", + column="event_id", + iterable=(annotation_event_id_bad, thread_event_id), + keyvalues={}, + desc="RelationsTestCase.test_background_update", + ) + ) + + # Only the "good" annotation should be found. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertEquals( + [ev["event_id"] for ev in channel.json_body["chunk"]], + [annotation_event_id_good], + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "event_arbitrary_relations", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + self.wait_for_background_updates() + + # The "good" annotation and the thread should be found, but not the "bad" + # annotation. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + self.assertCountEqual( + [ev["event_id"] for ev in channel.json_body["chunk"]], + [annotation_event_id_good, thread_event_id], + ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 376853fd65..10a4a4dc5e 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -25,7 +25,12 @@ from urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, +) from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin @@ -2157,6 +2162,153 @@ class LabelsTestCase(unittest.HomeserverTestCase): return event_id +class RelationsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["experimental_features"] = {"msc3440_enabled": True} + return config + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + self.helper.join( + room=self.room_id, user=self.second_user_id, tok=self.second_tok + ) + + self.third_user_id = self.register_user("third", "test") + self.third_tok = self.login("third", "test") + self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok) + + # An initial event with a relation from second user. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "Message 1"}, + tok=self.tok, + ) + self.event_id_1 = res["event_id"] + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": self.event_id_1, + "key": "π", + } + }, + tok=self.second_tok, + ) + + # Another event with a relation from third user. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "Message 2"}, + tok=self.tok, + ) + self.event_id_2 = res["event_id"] + self.helper.send_event( + room_id=self.room_id, + type="m.reaction", + content={ + "m.relates_to": { + "rel_type": RelationTypes.REFERENCE, + "event_id": self.event_id_2, + } + }, + tok=self.third_tok, + ) + + # An event with no relations. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "No relations"}, + tok=self.tok, + ) + + def _filter_messages(self, filter: JsonDict) -> List[JsonDict]: + """Make a request to /messages with a filter, returns the chunk of events.""" + channel = self.make_request( + "GET", + "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body["chunk"] + + def test_filter_relation_senders(self): + # Messages which second user reacted to. + filter = {"io.element.relation_senders": [self.second_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_1) + + # Messages which third user reacted to. + filter = {"io.element.relation_senders": [self.third_user_id]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_2) + + # Messages which either user reacted to. + filter = { + "io.element.relation_senders": [self.second_user_id, self.third_user_id] + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 2, chunk) + self.assertCountEqual( + [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] + ) + + def test_filter_relation_type(self): + # Messages which have annotations. + filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_1) + + # Messages which have references. + filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_2) + + # Messages which have either annotations or references. + filter = { + "io.element.relation_types": [ + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + ] + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 2, chunk) + self.assertCountEqual( + [c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2] + ) + + def test_filter_relation_senders_and_type(self): + # Messages which second user reacted to. + filter = { + "io.element.relation_senders": [self.second_user_id], + "io.element.relation_types": [RelationTypes.ANNOTATION], + } + chunk = self._filter_messages(filter) + self.assertEqual(len(chunk), 1, chunk) + self.assertEqual(chunk[0]["event_id"], self.event_id_1) + + class ContextTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index ec0979850b..1af5e5cee5 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -19,10 +19,21 @@ import json import re import time import urllib.parse -from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union +from typing import ( + Any, + AnyStr, + Dict, + Iterable, + Mapping, + MutableMapping, + Optional, + Tuple, + overload, +) from unittest.mock import patch import attr +from typing_extensions import Literal from twisted.web.resource import Resource from twisted.web.server import Site @@ -45,6 +56,32 @@ class RestHelper: site = attr.ib(type=Site) auth_user_id = attr.ib() + @overload + def create_room_as( + self, + room_creator: Optional[str] = ..., + is_public: Optional[bool] = ..., + room_version: Optional[str] = ..., + tok: Optional[str] = ..., + expect_code: Literal[200] = ..., + extra_content: Optional[Dict] = ..., + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., + ) -> str: + ... + + @overload + def create_room_as( + self, + room_creator: Optional[str] = ..., + is_public: Optional[bool] = ..., + room_version: Optional[str] = ..., + tok: Optional[str] = ..., + expect_code: int = ..., + extra_content: Optional[Dict] = ..., + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., + ) -> Optional[str]: + ... + def create_room_as( self, room_creator: Optional[str] = None, @@ -53,10 +90,8 @@ class RestHelper: tok: Optional[str] = None, expect_code: int = 200, extra_content: Optional[Dict] = None, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, - ) -> str: + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + ) -> Optional[str]: """ Create a room. @@ -99,6 +134,8 @@ class RestHelper: if expect_code == 200: return channel.json_body["room_id"] + else: + return None def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): self.change_membership( @@ -168,7 +205,7 @@ class RestHelper: extra_data: Optional[dict] = None, tok: Optional[str] = None, expect_code: int = 200, - expect_errcode: str = None, + expect_errcode: Optional[str] = None, ) -> None: """ Send a membership state event into a room. @@ -227,9 +264,7 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ): if body is None: body = "body_text_here" @@ -254,9 +289,7 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, - custom_headers: Optional[ - Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] - ] = None, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -418,7 +451,7 @@ class RestHelper: path, content=image_data, access_token=tok, - custom_headers=[(b"Content-Length", str(image_length))], + custom_headers=[("Content-Length", str(image_length))], ) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( @@ -503,7 +536,7 @@ class RestHelper: went. """ - cookies = {} + cookies: Dict[str, str] = {} # if we're doing a ui auth, hit the ui auth redirect endpoint if ui_auth_session_id: @@ -625,7 +658,13 @@ class RestHelper: # hit the redirect url again with the right Host header, which should now issue # a cookie and redirect to the SSO provider. - location = channel.headers.getRawHeaders("Location")[0] + def get_location(channel: FakeChannel) -> str: + location_values = channel.headers.getRawHeaders("Location") + # Keep mypy happy by asserting that location_values is nonempty + assert location_values + return location_values[0] + + location = get_location(channel) parts = urllib.parse.urlsplit(location) channel = make_request( self.hs.get_reactor(), @@ -639,7 +678,7 @@ class RestHelper: assert channel.code == 302 channel.extract_cookies(cookies) - return channel.headers.getRawHeaders("Location")[0] + return get_location(channel) def initiate_sso_ui_auth( self, ui_auth_session_id: str, cookies: MutableMapping[str, str] |