diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 2f7090e554..a7c6e595b9 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -66,7 +66,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
# Create a new group
channel = self.make_request(
"POST",
- "/create_group".encode("ascii"),
+ b"/create_group",
access_token=self.admin_user_tok,
content={"localpart": "test"},
)
@@ -129,9 +129,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
def _get_groups_user_is_in(self, access_token):
"""Returns the list of groups the user is in (given their access token)"""
- channel = self.make_request(
- "GET", "/joined_groups".encode("ascii"), access_token=access_token
- )
+ channel = self.make_request("GET", b"/joined_groups", access_token=access_token)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ee071c2477..17ec8bfd3b 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -535,7 +535,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
)
)
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+ self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
def _assert_peek(self, room_id, expect_code):
"""Assert that the admin user can (or cannot) peek into the room."""
@@ -599,7 +599,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
)
)
- self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+ self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
class RoomTestCase(unittest.HomeserverTestCase):
@@ -1280,7 +1280,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.public_room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=True
)
- self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+ self.url = f"/_synapse/admin/v1/join/{self.public_room_id}"
def test_requester_is_no_admin(self):
"""
@@ -1420,7 +1420,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
private_room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=False
)
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ url = f"/_synapse/admin/v1/join/{private_room_id}"
body = json.dumps({"user_id": self.second_user_id})
channel = self.make_request(
@@ -1463,7 +1463,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
# Join user to room.
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ url = f"/_synapse/admin/v1/join/{private_room_id}"
body = json.dumps({"user_id": self.second_user_id})
channel = self.make_request(
@@ -1493,7 +1493,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
private_room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok, is_public=False
)
- url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+ url = f"/_synapse/admin/v1/join/{private_room_id}"
body = json.dumps({"user_id": self.second_user_id})
channel = self.make_request(
@@ -1633,7 +1633,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
content={},
access_token=self.admin_user_tok,
)
@@ -1660,7 +1660,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
content={},
access_token=self.admin_user_tok,
)
@@ -1686,7 +1686,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
content={"user_id": self.second_user_id},
access_token=self.admin_user_tok,
)
@@ -1720,7 +1720,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+ f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
content={},
access_token=self.admin_user_tok,
)
@@ -1753,7 +1753,6 @@ PURGE_TABLES = [
"room_memberships",
"room_stats_state",
"room_stats_current",
- "room_stats_historical",
"room_stats_earliest_token",
"rooms",
"stream_ordering_to_exterm",
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e1fe72fc5d..28dd47a28b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -16,17 +16,19 @@ from typing import Dict
from unittest.mock import Mock
from synapse.events import EventBase
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import Requester, StateMap
+from synapse.util.frozenutils import unfreeze
from tests import unittest
thread_local = threading.local()
-class ThirdPartyRulesTestModule:
+class LegacyThirdPartyRulesTestModule:
def __init__(self, config: Dict, module_api: ModuleApi):
# keep a record of the "current" rules module, so that the test can patch
# it if desired.
@@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule:
return config
-def current_rules_module() -> ThirdPartyRulesTestModule:
- return thread_local.rules_module
+class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
+ def __init__(self, config: Dict, module_api: ModuleApi):
+ super().__init__(config, module_api)
+
+ def on_create_room(
+ self, requester: Requester, config: dict, is_requester_admin: bool
+ ):
+ return False
+
+
+class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
+ def __init__(self, config: Dict, module_api: ModuleApi):
+ super().__init__(config, module_api)
+
+ async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+ d = event.get_dict()
+ content = unfreeze(event.content)
+ content["foo"] = "bar"
+ d["content"] = content
+ return d
class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
@@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def default_config(self):
- config = super().default_config()
- config["third_party_event_rules"] = {
- "module": __name__ + ".ThirdPartyRulesTestModule",
- "config": {},
- }
- return config
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+
+ load_legacy_third_party_event_rules(hs)
+
+ return hs
def prepare(self, reactor, clock, homeserver):
# Create a user and room to play with during the tests
self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey")
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+ # Some tests might prevent room creation on purpose.
+ try:
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+ except Exception:
+ pass
def test_third_party_rules(self):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
@@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
# patch the rules module with a Mock which will return False for some event
# types
async def check(ev, state):
- return ev.type != "foo.bar.forbidden"
+ return ev.type != "foo.bar.forbidden", None
callback = Mock(spec=[], side_effect=check)
- current_rules_module().check_event_allowed = callback
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
+ callback
+ ]
channel = self.make_request(
"PUT",
@@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
# first patch the event checker so that it will try to modify the event
async def check(ev: EventBase, state):
ev.content = {"x": "y"}
- return True
+ return True, None
- current_rules_module().check_event_allowed = check
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# now send the event
channel = self.make_request(
@@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
{"x": "x"},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"500", channel.result)
+ # check_event_allowed has some error handling, so it shouldn't 500 just because a
+ # module did something bad.
+ self.assertEqual(channel.code, 200, channel.result)
+ event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["x"], "x")
def test_modify_event(self):
"""The module can return a modified version of the event"""
@@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
async def check(ev: EventBase, state):
d = ev.get_dict()
d["content"] = {"x": "y"}
- return d
+ return True, d
- current_rules_module().check_event_allowed = check
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# now send the event
channel = self.make_request(
@@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"msgtype": "m.text",
"body": d["content"]["body"].upper(),
}
- return d
+ return True, d
- current_rules_module().check_event_allowed = check
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# Send an event, then edit it.
channel = self.make_request(
@@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.assertEqual(ev["content"]["body"], "EDITED BODY")
def test_send_event(self):
- """Tests that the module can send an event into a room via the module api"""
+ """Tests that a module can send an event into a room via the module api"""
content = {
"msgtype": "m.text",
"body": "Hello!",
@@ -233,13 +270,60 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"content": content,
"sender": self.user_id,
}
- event = self.get_success(
- current_rules_module().module_api.create_and_send_event_into_room(
- event_dict
- )
- ) # type: EventBase
+ event: EventBase = self.get_success(
+ self.hs.get_module_api().create_and_send_event_into_room(event_dict)
+ )
self.assertEquals(event.sender, self.user_id)
self.assertEquals(event.room_id, self.room_id)
self.assertEquals(event.type, "m.room.message")
self.assertEquals(event.content, content)
+
+ @unittest.override_config(
+ {
+ "third_party_event_rules": {
+ "module": __name__ + ".LegacyChangeEvents",
+ "config": {},
+ }
+ }
+ )
+ def test_legacy_check_event_allowed(self):
+ """Tests that the wrapper for legacy check_event_allowed callbacks works
+ correctly.
+ """
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id,
+ {
+ "msgtype": "m.text",
+ "body": "Original body",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+
+ event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+
+ self.assertIn("foo", channel.json_body["content"].keys())
+ self.assertEqual(channel.json_body["content"]["foo"], "bar")
+
+ @unittest.override_config(
+ {
+ "third_party_event_rules": {
+ "module": __name__ + ".LegacyDenyNewRooms",
+ "config": {},
+ }
+ }
+ )
+ def test_legacy_on_create_room(self):
+ """Tests that the wrapper for legacy on_create_room callbacks works
+ correctly.
+ """
+ self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 605b952316..7eba69642a 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -453,7 +453,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# stick the flows results in a dict by type
- flow_results = {} # type: Dict[str, Any]
+ flow_results: Dict[str, Any] = {}
for f in channel.json_body["flows"]:
flow_type = f["type"]
self.assertNotIn(
@@ -501,7 +501,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
p.close()
# there should be a link for each href
- returned_idps = [] # type: List[str]
+ returned_idps: List[str] = []
for link in p.links:
path, query = link.split("?", 1)
self.assertEqual(path, "pick_idp")
@@ -582,7 +582,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# ... and should have set a cookie including the redirect url
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
assert cookie_headers
- cookies = {} # type: Dict[str, str]
+ cookies: Dict[str, str] = {}
for h in cookie_headers:
key, value = h.split(";")[0].split("=", maxsplit=1)
cookies[key] = value
@@ -874,9 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result = jwt.encode(
- payload, secret, self.jwt_algorithm
- ) # type: Union[str, bytes]
+ result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
if isinstance(result, bytes):
return result.decode("ascii")
return result
@@ -1084,7 +1082,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str]
+ result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
if isinstance(result, bytes):
return result.decode("ascii")
return result
@@ -1272,7 +1270,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
# ... with a username_mapping_session cookie
- cookies = {} # type: Dict[str,str]
+ cookies: Dict[str, str] = {}
channel.extract_cookies(cookies)
self.assertIn("username_mapping_session", cookies)
session_id = cookies["username_mapping_session"]
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index e94566ffd7..3df070c936 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1206,7 +1206,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/join".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/join",
content={"reason": reason},
access_token=self.second_tok,
)
@@ -1220,7 +1220,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/leave",
content={"reason": reason},
access_token=self.second_tok,
)
@@ -1234,7 +1234,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/kick",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.second_tok,
)
@@ -1248,7 +1248,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/ban",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
@@ -1260,7 +1260,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/unban",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
@@ -1272,7 +1272,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/invite",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
@@ -1291,7 +1291,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/leave",
content={"reason": reason},
access_token=self.second_tok,
)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 69798e95c3..fc2d35596e 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -19,7 +19,7 @@ import json
import re
import time
import urllib.parse
-from typing import Any, Dict, Mapping, MutableMapping, Optional
+from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
from unittest.mock import patch
import attr
@@ -53,6 +53,9 @@ class RestHelper:
tok: str = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
) -> str:
"""
Create a room.
@@ -87,6 +90,7 @@ class RestHelper:
"POST",
path,
json.dumps(content).encode("utf8"),
+ custom_headers=custom_headers,
)
assert channel.result["code"] == b"%d" % expect_code, channel.result
@@ -175,14 +179,30 @@ class RestHelper:
self.auth_user_id = temp_id
- def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
+ def send(
+ self,
+ room_id,
+ body=None,
+ txn_id=None,
+ tok=None,
+ expect_code=200,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
+ ):
if body is None:
body = "body_text_here"
content = {"msgtype": "m.text", "body": body}
return self.send_event(
- room_id, "m.room.message", content, txn_id, tok, expect_code
+ room_id,
+ "m.room.message",
+ content,
+ txn_id,
+ tok,
+ expect_code,
+ custom_headers=custom_headers,
)
def send_event(
@@ -193,6 +213,9 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -207,6 +230,7 @@ class RestHelper:
"PUT",
path,
json.dumps(content or {}).encode("utf8"),
+ custom_headers=custom_headers,
)
assert (
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 856aa8682f..2e2f94742e 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -273,7 +273,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
prev_token = None
found_event_ids = []
- encoded_key = urllib.parse.quote_plus("👍".encode("utf-8"))
+ encoded_key = urllib.parse.quote_plus("👍".encode())
for _ in range(20):
from_token = ""
if prev_token:
diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py
index 1ec6b05e5b..a76a6fef1e 100644
--- a/tests/rest/client/v2_alpha/test_report_event.py
+++ b/tests/rest/client/v2_alpha/test_report_event.py
@@ -41,7 +41,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok)
resp = self.helper.send(self.room_id, tok=self.admin_user_tok)
self.event_id = resp["event_id"]
- self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id)
+ self.report_path = f"rooms/{self.room_id}/report/{self.event_id}"
def test_reason_str_and_score_int(self):
data = {"reason": "this makes me sad", "score": -100}
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 95e7075841..2d6b49692e 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -310,7 +310,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
correctly decode it as the UTF-8 string, and use filename* in the
response.
"""
- filename = parse.quote("\u2603".encode("utf8")).encode("ascii")
+ filename = parse.quote("\u2603".encode()).encode("ascii")
channel = self._req(
b"inline; filename*=utf-8''" + filename + self.test_image.extension
)
|