diff options
Diffstat (limited to 'tests/rest')
-rw-r--r-- | tests/rest/admin/test_admin.py | 6 | ||||
-rw-r--r-- | tests/rest/admin/test_room.py | 21 | ||||
-rw-r--r-- | tests/rest/client/test_third_party_rules.py | 136 | ||||
-rw-r--r-- | tests/rest/client/v1/test_login.py | 14 | ||||
-rw-r--r-- | tests/rest/client/v1/test_rooms.py | 14 | ||||
-rw-r--r-- | tests/rest/client/v1/utils.py | 30 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_relations.py | 2 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_report_event.py | 2 | ||||
-rw-r--r-- | tests/rest/media/v1/test_media_storage.py | 2 |
9 files changed, 165 insertions, 62 deletions
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 2f7090e554..a7c6e595b9 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -66,7 +66,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): # Create a new group channel = self.make_request( "POST", - "/create_group".encode("ascii"), + b"/create_group", access_token=self.admin_user_tok, content={"localpart": "test"}, ) @@ -129,9 +129,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token)""" - channel = self.make_request( - "GET", "/joined_groups".encode("ascii"), access_token=access_token - ) + channel = self.make_request("GET", b"/joined_groups", access_token=access_token) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index ee071c2477..17ec8bfd3b 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -535,7 +535,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) ) - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") def _assert_peek(self, room_id, expect_code): """Assert that the admin user can (or cannot) peek into the room.""" @@ -599,7 +599,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): ) ) - self.assertEqual(count, 0, msg="Rows not purged in {}".format(table)) + self.assertEqual(count, 0, msg=f"Rows not purged in {table}") class RoomTestCase(unittest.HomeserverTestCase): @@ -1280,7 +1280,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): self.public_room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=True ) - self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id) + self.url = f"/_synapse/admin/v1/join/{self.public_room_id}" def test_requester_is_no_admin(self): """ @@ -1420,7 +1420,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): private_room_id = self.helper.create_room_as( self.creator, tok=self.creator_tok, is_public=False ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1463,7 +1463,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): # Join user to room. - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1493,7 +1493,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase): private_room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok, is_public=False ) - url = "/_synapse/admin/v1/join/{}".format(private_room_id) + url = f"/_synapse/admin/v1/join/{private_room_id}" body = json.dumps({"user_id": self.second_user_id}) channel = self.make_request( @@ -1633,7 +1633,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1660,7 +1660,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1686,7 +1686,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={"user_id": self.second_user_id}, access_token=self.admin_user_tok, ) @@ -1720,7 +1720,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id), + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", content={}, access_token=self.admin_user_tok, ) @@ -1753,7 +1753,6 @@ PURGE_TABLES = [ "room_memberships", "room_stats_state", "room_stats_current", - "room_stats_historical", "room_stats_earliest_token", "rooms", "stream_ordering_to_exterm", diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index e1fe72fc5d..28dd47a28b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -16,17 +16,19 @@ from typing import Dict from unittest.mock import Mock from synapse.events import EventBase +from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.types import Requester, StateMap +from synapse.util.frozenutils import unfreeze from tests import unittest thread_local = threading.local() -class ThirdPartyRulesTestModule: +class LegacyThirdPartyRulesTestModule: def __init__(self, config: Dict, module_api: ModuleApi): # keep a record of the "current" rules module, so that the test can patch # it if desired. @@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule: return config -def current_rules_module() -> ThirdPartyRulesTestModule: - return thread_local.rules_module +class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule): + def __init__(self, config: Dict, module_api: ModuleApi): + super().__init__(config, module_api) + + def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ): + return False + + +class LegacyChangeEvents(LegacyThirdPartyRulesTestModule): + def __init__(self, config: Dict, module_api: ModuleApi): + super().__init__(config, module_api) + + async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): + d = event.get_dict() + content = unfreeze(event.content) + content["foo"] = "bar" + d["content"] = content + return d class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): @@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def default_config(self): - config = super().default_config() - config["third_party_event_rules"] = { - "module": __name__ + ".ThirdPartyRulesTestModule", - "config": {}, - } - return config + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + + load_legacy_third_party_event_rules(hs) + + return hs def prepare(self, reactor, clock, homeserver): # Create a user and room to play with during the tests self.user_id = self.register_user("kermit", "monkey") self.tok = self.login("kermit", "monkey") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + # Some tests might prevent room creation on purpose. + try: + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + except Exception: + pass def test_third_party_rules(self): """Tests that a forbidden event is forbidden from being sent, but an allowed one @@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): # patch the rules module with a Mock which will return False for some event # types async def check(ev, state): - return ev.type != "foo.bar.forbidden" + return ev.type != "foo.bar.forbidden", None callback = Mock(spec=[], side_effect=check) - current_rules_module().check_event_allowed = callback + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [ + callback + ] channel = self.make_request( "PUT", @@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): # first patch the event checker so that it will try to modify the event async def check(ev: EventBase, state): ev.content = {"x": "y"} - return True + return True, None - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # now send the event channel = self.make_request( @@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): {"x": "x"}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"500", channel.result) + # check_event_allowed has some error handling, so it shouldn't 500 just because a + # module did something bad. + self.assertEqual(channel.code, 200, channel.result) + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["x"], "x") def test_modify_event(self): """The module can return a modified version of the event""" @@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): async def check(ev: EventBase, state): d = ev.get_dict() d["content"] = {"x": "y"} - return d + return True, d - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # now send the event channel = self.make_request( @@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "msgtype": "m.text", "body": d["content"]["body"].upper(), } - return d + return True, d - current_rules_module().check_event_allowed = check + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] # Send an event, then edit it. channel = self.make_request( @@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): self.assertEqual(ev["content"]["body"], "EDITED BODY") def test_send_event(self): - """Tests that the module can send an event into a room via the module api""" + """Tests that a module can send an event into a room via the module api""" content = { "msgtype": "m.text", "body": "Hello!", @@ -233,13 +270,60 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): "content": content, "sender": self.user_id, } - event = self.get_success( - current_rules_module().module_api.create_and_send_event_into_room( - event_dict - ) - ) # type: EventBase + event: EventBase = self.get_success( + self.hs.get_module_api().create_and_send_event_into_room(event_dict) + ) self.assertEquals(event.sender, self.user_id) self.assertEquals(event.room_id, self.room_id) self.assertEquals(event.type, "m.room.message") self.assertEquals(event.content, content) + + @unittest.override_config( + { + "third_party_event_rules": { + "module": __name__ + ".LegacyChangeEvents", + "config": {}, + } + } + ) + def test_legacy_check_event_allowed(self): + """Tests that the wrapper for legacy check_event_allowed callbacks works + correctly. + """ + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id, + { + "msgtype": "m.text", + "body": "Original body", + }, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + + self.assertIn("foo", channel.json_body["content"].keys()) + self.assertEqual(channel.json_body["content"]["foo"], "bar") + + @unittest.override_config( + { + "third_party_event_rules": { + "module": __name__ + ".LegacyDenyNewRooms", + "config": {}, + } + } + ) + def test_legacy_on_create_room(self): + """Tests that the wrapper for legacy on_create_room callbacks works + correctly. + """ + self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 605b952316..7eba69642a 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -453,7 +453,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # stick the flows results in a dict by type - flow_results = {} # type: Dict[str, Any] + flow_results: Dict[str, Any] = {} for f in channel.json_body["flows"]: flow_type = f["type"] self.assertNotIn( @@ -501,7 +501,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): p.close() # there should be a link for each href - returned_idps = [] # type: List[str] + returned_idps: List[str] = [] for link in p.links: path, query = link.split("?", 1) self.assertEqual(path, "pick_idp") @@ -582,7 +582,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # ... and should have set a cookie including the redirect url cookie_headers = channel.headers.getRawHeaders("Set-Cookie") assert cookie_headers - cookies = {} # type: Dict[str, str] + cookies: Dict[str, str] = {} for h in cookie_headers: key, value = h.split(";")[0].split("=", maxsplit=1) cookies[key] = value @@ -874,9 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase): def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode( - payload, secret, self.jwt_algorithm - ) # type: Union[str, bytes] + result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) if isinstance(result, bytes): return result.decode("ascii") return result @@ -1084,7 +1082,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str] + result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") if isinstance(result, bytes): return result.decode("ascii") return result @@ -1272,7 +1270,7 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") # ... with a username_mapping_session cookie - cookies = {} # type: Dict[str,str] + cookies: Dict[str, str] = {} channel.extract_cookies(cookies) self.assertIn("username_mapping_session", cookies) session_id = cookies["username_mapping_session"] diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e94566ffd7..3df070c936 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1206,7 +1206,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/join".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/join", content={"reason": reason}, access_token=self.second_tok, ) @@ -1220,7 +1220,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/leave", content={"reason": reason}, access_token=self.second_tok, ) @@ -1234,7 +1234,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/kick", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) @@ -1248,7 +1248,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/ban", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1260,7 +1260,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/unban", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1272,7 +1272,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/invite", content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) @@ -1291,7 +1291,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): reason = "hello" channel = self.make_request( "POST", - "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + f"/_matrix/client/r0/rooms/{self.room_id}/leave", content={"reason": reason}, access_token=self.second_tok, ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 69798e95c3..fc2d35596e 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -19,7 +19,7 @@ import json import re import time import urllib.parse -from typing import Any, Dict, Mapping, MutableMapping, Optional +from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union from unittest.mock import patch import attr @@ -53,6 +53,9 @@ class RestHelper: tok: str = None, expect_code: int = 200, extra_content: Optional[Dict] = None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ) -> str: """ Create a room. @@ -87,6 +90,7 @@ class RestHelper: "POST", path, json.dumps(content).encode("utf8"), + custom_headers=custom_headers, ) assert channel.result["code"] == b"%d" % expect_code, channel.result @@ -175,14 +179,30 @@ class RestHelper: self.auth_user_id = temp_id - def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): + def send( + self, + room_id, + body=None, + txn_id=None, + tok=None, + expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): if body is None: body = "body_text_here" content = {"msgtype": "m.text", "body": body} return self.send_event( - room_id, "m.room.message", content, txn_id, tok, expect_code + room_id, + "m.room.message", + content, + txn_id, + tok, + expect_code, + custom_headers=custom_headers, ) def send_event( @@ -193,6 +213,9 @@ class RestHelper: txn_id=None, tok=None, expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -207,6 +230,7 @@ class RestHelper: "PUT", path, json.dumps(content or {}).encode("utf8"), + custom_headers=custom_headers, ) assert ( diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 856aa8682f..2e2f94742e 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -273,7 +273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): prev_token = None found_event_ids = [] - encoded_key = urllib.parse.quote_plus("👍".encode("utf-8")) + encoded_key = urllib.parse.quote_plus("👍".encode()) for _ in range(20): from_token = "" if prev_token: diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py index 1ec6b05e5b..a76a6fef1e 100644 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ b/tests/rest/client/v2_alpha/test_report_event.py @@ -41,7 +41,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase): self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) resp = self.helper.send(self.room_id, tok=self.admin_user_tok) self.event_id = resp["event_id"] - self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id) + self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" def test_reason_str_and_score_int(self): data = {"reason": "this makes me sad", "score": -100} diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 95e7075841..2d6b49692e 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -310,7 +310,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): correctly decode it as the UTF-8 string, and use filename* in the response. """ - filename = parse.quote("\u2603".encode("utf8")).encode("ascii") + filename = parse.quote("\u2603".encode()).encode("ascii") channel = self._req( b"inline; filename*=utf-8''" + filename + self.test_image.extension ) |