diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index a43a137273..c1a7fb2f8a 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import os
import re
from email.parser import Parser
+from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
@@ -95,10 +95,8 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
"""
body = {"type": "m.login.password", "user": username, "password": password}
- channel = self.make_request(
- "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
- )
- self.assertEqual(channel.code, 403, channel.result)
+ channel = self.make_request("POST", "/_matrix/client/r0/login", body)
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
def test_basic_password_reset(self) -> None:
"""Test basic password reset flow"""
@@ -347,7 +345,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
# password reset confirm button
@@ -362,7 +360,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
content_is_form=True,
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@@ -390,7 +388,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
new_password: str,
session_id: str,
client_secret: str,
- expected_code: int = 200,
+ expected_code: int = HTTPStatus.OK,
) -> None:
channel = self.make_request(
"POST",
@@ -479,20 +477,18 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.assertEqual(memberships[0].room_id, room_id, memberships)
def deactivate(self, user_id: str, tok: str) -> None:
- request_data = json.dumps(
- {
- "auth": {
- "type": "m.login.password",
- "user": user_id,
- "password": "test",
- },
- "erase": False,
- }
- )
+ request_data = {
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "test",
+ },
+ "erase": False,
+ }
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, channel.json_body)
class WhoamiTestCase(unittest.HomeserverTestCase):
@@ -645,21 +641,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_add_email_no_at(self) -> None:
self._request_token_invalid_email(
"address-without-at.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_two_at(self) -> None:
self._request_token_invalid_email(
"foo@foo@test.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_bad_format(self) -> None:
self._request_token_invalid_email(
"user@bad.example.net@good.example.com",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
@@ -715,7 +711,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -725,7 +723,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self) -> None:
@@ -747,7 +745,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -756,7 +754,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self) -> None:
@@ -781,7 +779,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -791,7 +791,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
@@ -817,7 +817,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -827,7 +829,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_no_valid_token(self) -> None:
@@ -852,7 +854,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -862,7 +866,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@override_config({"next_link_domain_whitelist": None})
@@ -872,7 +876,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/good/site",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
@override_config({"next_link_domain_whitelist": None})
@@ -884,7 +888,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
@override_config({"next_link_domain_whitelist": None})
@@ -895,7 +899,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="file:///host/path",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
@@ -907,28 +911,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link=None,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.com/some/good/page",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.org/some/also/good/page",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://bad.example.org/some/bad/page",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
@override_config({"next_link_domain_whitelist": []})
@@ -940,7 +944,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/page",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
def _request_token(
@@ -948,8 +952,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
email: str,
client_secret: str,
next_link: Optional[str] = None,
- expect_code: int = 200,
- ) -> str:
+ expect_code: int = HTTPStatus.OK,
+ ) -> Optional[str]:
"""Request a validation token to add an email address to a user's account
Args:
@@ -959,7 +963,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
expect_code: Expected return code of the call
Returns:
- The ID of the new threepid validation session
+ The ID of the new threepid validation session, or None if the response
+ did not contain a session ID.
"""
body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
if next_link:
@@ -992,16 +997,18 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(expected_errcode, channel.json_body["errcode"])
- self.assertEqual(expected_error, channel.json_body["error"])
+ self.assertIn(expected_error, channel.json_body["error"])
def _validate_token(self, link: str) -> None:
# Remove the host
path = link.replace("https://example.com", "")
channel = self.make_request("GET", path, shorthand=False)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@@ -1051,7 +1058,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -1060,7 +1067,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@@ -1091,7 +1098,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that not providing any MXID raises an error."""
self._test_status(
users=None,
- expected_status_code=400,
+ expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.MISSING_PARAM,
)
@@ -1099,7 +1106,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that providing an invalid MXID raises an error."""
self._test_status(
users=["bad:test"],
- expected_status_code=400,
+ expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.INVALID_PARAM,
)
@@ -1285,7 +1292,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
def _test_status(
self,
users: Optional[List[str]],
- expected_status_code: int = 200,
+ expected_status_code: int = HTTPStatus.OK,
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
expected_failures: Optional[List[str]] = None,
expected_errcode: Optional[str] = None,
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index aca03afd0e..7a88aa2cda 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
+from synapse.appservice import ApplicationService
from synapse.rest import admin
from synapse.rest.client import directory, login, room
from synapse.server import HomeServer
@@ -96,8 +96,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# We use deliberately a localpart under the length threshold so
# that we can make sure that the check is done on the whole alias.
- data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
- request_data = json.dumps(data)
+ request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
@@ -109,8 +108,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# Check with an alias of allowed length. There should already be
# a test that ensures it works in test_register.py, but let's be
# as cautious as possible here.
- data = {"room_alias_name": random_string(5)}
- request_data = json.dumps(data)
+ request_data = {"room_alias_name": random_string(5)}
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
@@ -129,6 +127,38 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ def test_deleting_alias_via_directory_appservice(self) -> None:
+ user_id = "@as:test"
+ as_token = "i_am_an_app_service"
+
+ appservice = ApplicationService(
+ as_token,
+ id="1234",
+ namespaces={"aliases": [{"regex": "#asns-*", "exclusive": True}]},
+ sender=user_id,
+ )
+ self.hs.get_datastores().main.services_cache.append(appservice)
+
+ # Add an alias for the room, as the appservice
+ alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string()
+ request_data = {"room_id": self.room_id}
+
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ request_data,
+ access_token=as_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ # Then try to remove the alias, as the appservice
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=as_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
def test_deleting_nonexistant_alias(self) -> None:
# Check that no alias exists
alias = "#potato:test"
@@ -159,8 +189,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.hs.hostname,
)
- data = {"aliases": [self.random_alias(alias_length)]}
- request_data = json.dumps(data)
+ request_data = {"aliases": [self.random_alias(alias_length)]}
channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
@@ -172,8 +201,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
) -> str:
alias = self.random_alias(alias_length)
url = "/_matrix/client/r0/directory/room/%s" % alias
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
@@ -181,6 +209,19 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, expected_code, channel.result)
return alias
+ def test_invalid_alias(self) -> None:
+ alias = "#potato"
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/r0/directory/room/{alias}",
+ access_token=self.user_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertIn("error", channel.json_body, channel.json_body)
+ self.assertEqual(
+ channel.json_body["errcode"], "M_INVALID_PARAM", channel.json_body
+ )
+
def random_alias(self, length: int) -> str:
return RoomAlias(random_string(length), self.hs.hostname).to_string()
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 823e8ab8c4..afc8d641be 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -43,7 +43,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON,
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.get_success(
self.store.get_user_filter(user_localpart="apple", filter_id=0)
@@ -58,7 +58,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.EXAMPLE_FILTER_JSON,
)
- self.assertEqual(channel.result["code"], b"403")
+ self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_add_filter_non_local_user(self) -> None:
@@ -71,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
)
self.hs.is_mine = _is_mine
- self.assertEqual(channel.result["code"], b"403")
+ self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_get_filter(self) -> None:
@@ -85,7 +85,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self) -> None:
@@ -93,7 +93,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"404")
+ self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
@@ -103,7 +103,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.code, 400)
# No ID also returns an invalid_id error
def test_get_filter_no_id(self) -> None:
@@ -111,4 +111,4 @@ class FilterTestCase(unittest.HomeserverTestCase):
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
- self.assertEqual(channel.result["code"], b"400")
+ self.assertEqual(channel.code, 400)
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 299b9d21e2..b0c8215744 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
@@ -26,7 +25,6 @@ from tests import unittest
class IdentityTestCase(unittest.HomeserverTestCase):
-
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -34,7 +32,6 @@ class IdentityTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
config = self.default_config()
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
@@ -51,12 +48,12 @@ class IdentityTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
room_id = channel.json_body["room_id"]
- params = {
+ request_data = {
"id_server": "testis",
"medium": "email",
"address": "test@example.com",
+ "id_access_token": tok,
}
- request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
channel = self.make_request(
b"POST", request_url, request_data, access_token=tok
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f4ea1209d9..e2a4d98275 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import time
import urllib.parse
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode
@@ -41,7 +40,7 @@ from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
try:
- import jwt
+ from authlib.jose import jwk, jwt
HAS_JWT = True
except ImportError:
@@ -134,10 +133,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -152,7 +151,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config(
{
@@ -179,10 +178,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -197,7 +196,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config(
{
@@ -224,10 +223,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", LOGIN_URL, params)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
@@ -242,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
@override_config({"session_lifetime": "24h"})
def test_soft_logout(self) -> None:
@@ -250,7 +249,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we shouldn't be able to make requests without an access token
channel = self.make_request(b"GET", TEST_URL)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
# log in as normal
@@ -354,7 +353,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# Now try to hard logout this session
channel = self.make_request(b"POST", "/logout", access_token=access_token)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
@@ -380,7 +379,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# Now try to hard log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_with_overly_long_device_id_fails(self) -> None:
self.register_user("mickey", "cheese")
@@ -399,7 +398,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
"/_matrix/client/v3/login",
- json.dumps(body).encode("utf8"),
+ body,
custom_headers=None,
)
@@ -841,7 +840,7 @@ class CASTestCase(unittest.HomeserverTestCase):
self.assertIn(b"SSO account deactivated", channel.result["body"])
-@skip_unless(HAS_JWT, "requires jwt")
+@skip_unless(HAS_JWT, "requires authlib")
class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -866,11 +865,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
- # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
- if isinstance(result, bytes):
- return result.decode("ascii")
- return result
+ header = {"alg": self.jwt_algorithm}
+ result: bytes = jwt.encode(header, payload, secret)
+ return result.decode("ascii")
def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
@@ -880,17 +877,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid_registered(self) -> None:
self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_valid_unregistered(self) -> None:
channel = self.jwt_login({"sub": "frog"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, "notsecret")
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -899,25 +896,26 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_expired(self) -> None:
channel = self.jwt_login({"sub": "frog", "exp": 864000})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Signature has expired"
+ channel.json_body["error"],
+ "JWT validation failed: expired_token: The token is expired",
)
def test_login_jwt_not_before(self) -> None:
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
- "JWT validation failed: The token is not yet valid (nbf)",
+ "JWT validation failed: invalid_token: The token is not valid yet",
)
def test_login_no_sub(self) -> None:
channel = self.jwt_login({"username": "root"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
@@ -926,30 +924,31 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the issuer claim."""
# A valid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Invalid issuer"
+ channel.json_body["error"],
+ 'JWT validation failed: invalid_claim: Invalid claim "iss"',
)
# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
- 'JWT validation failed: Token is missing the "iss" claim',
+ 'JWT validation failed: missing_claim: Missing "iss" claim',
)
def test_login_iss_no_config(self) -> None:
"""Test providing an issuer claim without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
@@ -957,52 +956,54 @@ class JWTTestCase(unittest.HomeserverTestCase):
"""Test validating the audience claim."""
# A valid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
# An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Invalid audience"
+ channel.json_body["error"],
+ 'JWT validation failed: invalid_claim: Invalid claim "aud"',
)
# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
- 'JWT validation failed: Token is missing the "aud" claim',
+ 'JWT validation failed: missing_claim: Missing "aud" claim',
)
def test_login_aud_no_config(self) -> None:
"""Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
- channel.json_body["error"], "JWT validation failed: Invalid audience"
+ channel.json_body["error"],
+ 'JWT validation failed: invalid_claim: Invalid claim "aud"',
)
def test_login_default_sub(self) -> None:
"""Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self) -> None:
"""Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -1010,7 +1011,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key.
-@skip_unless(HAS_JWT, "requires jwt")
+@skip_unless(HAS_JWT, "requires authlib")
class JWTPubKeyTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
@@ -1071,11 +1072,11 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
return config
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
- # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
- if isinstance(result, bytes):
- return result.decode("ascii")
- return result
+ header = {"alg": "RS256"}
+ if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"):
+ secret = jwk.dumps(secret, kty="RSA")
+ result: bytes = jwt.encode(header, payload, secret)
+ return result.decode("ascii")
def jwt_login(self, *args: Any) -> FakeChannel:
params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
@@ -1084,12 +1085,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_valid(self) -> None:
channel = self.jwt_login({"sub": "kermit"})
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
@@ -1150,7 +1151,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_user_bot(self) -> None:
"""Test that the appservice bot can use /login"""
@@ -1164,7 +1165,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_login_appservice_wrong_user(self) -> None:
"""Test that non-as users cannot login with the as token"""
@@ -1178,7 +1179,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.service.token
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_wrong_as(self) -> None:
"""Test that as users cannot login with wrong as token"""
@@ -1192,7 +1193,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
b"POST", LOGIN_URL, params, access_token=self.another_service.token
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
def test_login_appservice_no_token(self) -> None:
"""Test that users must provide a token when using the appservice
@@ -1206,7 +1207,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
@skip_unless(HAS_OIDC, "requires OIDC")
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
new file mode 100644
index 0000000000..a9da00665e
--- /dev/null
+++ b/tests/rest/client/test_models.py
@@ -0,0 +1,53 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+from pydantic import ValidationError
+
+from synapse.rest.client.models import EmailRequestTokenBody
+
+
+class EmailRequestTokenBodyTestCase(unittest.TestCase):
+ base_request = {
+ "client_secret": "hunter2",
+ "email": "alice@wonderland.com",
+ "send_attempt": 1,
+ }
+
+ def test_token_required_if_id_server_provided(self) -> None:
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ }
+ )
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ "id_access_token": None,
+ }
+ )
+
+ def test_token_typechecked_when_id_server_provided(self) -> None:
+ with self.assertRaises(ValidationError):
+ EmailRequestTokenBody.parse_obj(
+ {
+ **self.base_request,
+ "id_server": "identity.wonderland.com",
+ "id_access_token": 1337,
+ }
+ )
diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 3a74d2e96c..e19d21d6ee 100644
--- a/tests/rest/client/test_password_policy.py
+++ b/tests/rest/client/test_password_policy.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
@@ -89,7 +88,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_too_short(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "shorty"})
+ request_data = {"username": "kermit", "password": "shorty"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -100,7 +99,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_digit(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
+ request_data = {"username": "kermit", "password": "longerpassword"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -111,7 +110,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_symbol(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
+ request_data = {"username": "kermit", "password": "l0ngerpassword"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -122,7 +121,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_uppercase(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
+ request_data = {"username": "kermit", "password": "l0ngerpassword!"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -133,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_lowercase(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
+ request_data = {"username": "kermit", "password": "L0NGERPASSWORD!"}
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -144,7 +143,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_compliant(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
+ request_data = {"username": "kermit", "password": "L0ngerpassword!"}
channel = self.make_request("POST", self.register_url, request_data)
# Getting a 401 here means the password has passed validation and the server has
@@ -161,16 +160,14 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
user_id = self.register_user("kermit", compliant_password)
tok = self.login("kermit", compliant_password)
- request_data = json.dumps(
- {
- "new_password": not_compliant_password,
- "auth": {
- "password": compliant_password,
- "type": LoginType.PASSWORD,
- "user": user_id,
- },
- }
- )
+ request_data = {
+ "new_password": not_compliant_password,
+ "auth": {
+ "password": compliant_password,
+ "type": LoginType.PASSWORD,
+ "user": user_id,
+ },
+ }
channel = self.make_request(
"POST",
"/_matrix/client/r0/account/password",
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 77c3ced42e..8de5a342ae 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -13,6 +13,8 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
+import urllib.parse
+from http import HTTPStatus
from typing import Any, Dict, Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -49,6 +51,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
res = self._get_displayname()
self.assertEqual(res, "owner")
+ def test_get_displayname_rejects_bad_username(self) -> None:
+ channel = self.make_request(
+ "GET", f"/profile/{urllib.parse.quote('@alice:')}/displayname"
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+
def test_set_displayname(self) -> None:
channel = self.make_request(
"PUT",
@@ -145,18 +153,22 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
- def _get_displayname(self, name: Optional[str] = None) -> str:
+ def _get_displayname(self, name: Optional[str] = None) -> Optional[str]:
channel = self.make_request(
"GET", "/profile/%s/displayname" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
- return channel.json_body["displayname"]
+ # FIXME: If a user has no displayname set, Synapse returns 200 and omits a
+ # displayname from the response. This contradicts the spec, see #13137.
+ return channel.json_body.get("displayname")
- def _get_avatar_url(self, name: Optional[str] = None) -> str:
+ def _get_avatar_url(self, name: Optional[str] = None) -> Optional[str]:
channel = self.make_request(
"GET", "/profile/%s/avatar_url" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
+ # FIXME: If a user has no avatar set, Synapse returns 200 and omits an
+ # avatar_url from the response. This contradicts the spec, see #13137.
return channel.json_body.get("avatar_url")
@unittest.override_config({"max_avatar_size": 50})
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 7401b5e0c0..be4c67d68e 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -76,12 +76,12 @@ class RedactionsTestCase(HomeserverTestCase):
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
channel = self.make_request("POST", path, content={}, access_token=access_token)
- self.assertEqual(int(channel.result["code"]), expect_code)
+ self.assertEqual(channel.code, expect_code)
return channel.json_body
def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]:
channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
- self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.code, 200)
room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"]
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index afb08b2736..b781875d52 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
-import json
import os
from typing import Any, Dict, List, Tuple
@@ -62,15 +61,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
self.hs.get_datastores().main.services_cache.append(appservice)
- request_data = json.dumps(
- {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
- )
+ request_data = {
+ "username": "as_user_kermit",
+ "type": APP_SERVICE_REGISTRATION_TYPE,
+ }
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
self.assertDictContainsSubset(det_data, channel.json_body)
@@ -85,49 +85,46 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
self.hs.get_datastores().main.services_cache.append(appservice)
- request_data = json.dumps({"username": "as_user_kermit"})
+ request_data = {"username": "as_user_kermit"}
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
def test_POST_appservice_registration_invalid(self) -> None:
self.appservice = None # no application service exists
- request_data = json.dumps(
- {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
- )
+ request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
def test_POST_bad_password(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": 666})
+ request_data = {"username": "kermit", "password": 666}
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self) -> None:
- request_data = json.dumps({"username": 777, "password": "monkey"})
+ request_data = {"username": 777, "password": "monkey"}
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"400", channel.result)
+ self.assertEqual(channel.code, 400, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self) -> None:
user_id = "@kermit:test"
device_id = "frogfone"
- params = {
+ request_data = {
"username": "kermit",
"password": "monkey",
"device_id": device_id,
"auth": {"type": LoginType.DUMMY},
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
det_data = {
@@ -135,17 +132,17 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"home_server": self.hs.hostname,
"device_id": device_id,
}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
@override_config({"enable_registration": False})
def test_POST_disabled_registration(self) -> None:
- request_data = json.dumps({"username": "kermit", "password": "monkey"})
+ request_data = {"username": "kermit", "password": "monkey"}
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Registration has been disabled")
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -156,7 +153,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
def test_POST_disabled_guest_registration(self) -> None:
@@ -164,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["error"], "Guest access is disabled")
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
@@ -174,40 +171,39 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(b"POST", url, b"{}")
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self) -> None:
for i in range(0, 6):
- params = {
+ request_data = {
"username": "kermit" + str(i),
"password": "monkey",
"device_id": "frogfone",
"auth": {"type": LoginType.DUMMY},
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
@override_config({"registration_requires_token": True})
def test_POST_registration_requires_token(self) -> None:
@@ -234,8 +230,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
# Request without auth to get flows and session
- channel = self.make_request(b"POST", self.url, json.dumps(params))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# Synapse adds a dummy stage to differentiate flows where otherwise one
# flow would be a subset of another flow.
@@ -251,9 +247,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
- request_data = json.dumps(params)
- channel = self.make_request(b"POST", self.url, request_data)
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -262,14 +257,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"type": LoginType.DUMMY,
"session": session,
}
- request_data = json.dumps(params)
- channel = self.make_request(b"POST", self.url, request_data)
+ channel = self.make_request(b"POST", self.url, params)
det_data = {
"user_id": f"@{username}:{self.hs.hostname}",
"home_server": self.hs.hostname,
"device_id": device_id,
}
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
# Check the `completed` counter has been incremented and pending is 0
@@ -290,7 +284,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"password": "monkey",
}
# Request without auth to get session
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
session = channel.json_body["session"]
# Test with token param missing (invalid)
@@ -298,22 +292,22 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"type": LoginType.REGISTRATION_TOKEN,
"session": session,
}
- channel = self.make_request(b"POST", self.url, json.dumps(params))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with non-string (invalid)
params["auth"]["token"] = 1234
- channel = self.make_request(b"POST", self.url, json.dumps(params))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with unknown token (invalid)
params["auth"]["token"] = "1234"
- channel = self.make_request(b"POST", self.url, json.dumps(params))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -337,9 +331,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
params1: JsonDict = {"username": "bert", "password": "monkey"}
params2: JsonDict = {"username": "ernie", "password": "monkey"}
# Do 2 requests without auth to get two session IDs
- channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ channel1 = self.make_request(b"POST", self.url, params1)
session1 = channel1.json_body["session"]
- channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ channel2 = self.make_request(b"POST", self.url, params2)
session2 = channel2.json_body["session"]
# Use token with session1 and check `pending` is 1
@@ -348,9 +342,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session1,
}
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
# Repeat request to make sure pending isn't increased again
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
pending = self.get_success(
store.db_pool.simple_select_one_onecol(
"registration_tokens",
@@ -366,14 +360,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session2,
}
- channel = self.make_request(b"POST", self.url, json.dumps(params2))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params2)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
# Check pending=0 and completed=1
res = self.get_success(
store.db_pool.simple_select_one(
@@ -386,8 +380,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(res["completed"], 1)
# Check auth still fails when using token with session2
- channel = self.make_request(b"POST", self.url, json.dumps(params2))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params2)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -411,7 +405,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
params: JsonDict = {"username": "kermit", "password": "monkey"}
# Request without auth to get session
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
session = channel.json_body["session"]
# Check authentication fails with expired token
@@ -420,8 +414,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
- channel = self.make_request(b"POST", self.url, json.dumps(params))
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ channel = self.make_request(b"POST", self.url, params)
+ self.assertEqual(channel.code, 401, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -435,7 +429,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
# Check authentication succeeds
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -460,9 +454,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Do 2 requests without auth to get two session IDs
params1: JsonDict = {"username": "bert", "password": "monkey"}
params2: JsonDict = {"username": "ernie", "password": "monkey"}
- channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ channel1 = self.make_request(b"POST", self.url, params1)
session1 = channel1.json_body["session"]
- channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ channel2 = self.make_request(b"POST", self.url, params2)
session2 = channel2.json_body["session"]
# Use token with both sessions
@@ -471,18 +465,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session1,
}
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
params2["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session2,
}
- self.make_request(b"POST", self.url, json.dumps(params2))
+ self.make_request(b"POST", self.url, params2)
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
- self.make_request(b"POST", self.url, json.dumps(params1))
+ self.make_request(b"POST", self.url, params1)
# Check `result` of registration token stage for session1 is `True`
result1 = self.get_success(
@@ -550,7 +544,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Do request without auth to get a session ID
params: JsonDict = {"username": "kermit", "password": "monkey"}
- channel = self.make_request(b"POST", self.url, json.dumps(params))
+ channel = self.make_request(b"POST", self.url, params)
session = channel.json_body["session"]
# Use token
@@ -559,7 +553,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
- self.make_request(b"POST", self.url, json.dumps(params))
+ self.make_request(b"POST", self.url, params)
# Delete token
self.get_success(
@@ -576,7 +570,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_advertised_flows(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# with the stock config, we only expect the dummy flow
@@ -592,14 +586,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"require_at_registration": True,
},
"account_threepid_delegates": {
- "email": "https://id_server",
"msisdn": "https://id_server",
},
+ "email": {"notif_from": "Synapse <synapse@example.com>"},
}
)
def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
self.assertCountEqual(
@@ -631,7 +625,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
def test_advertised_flows_no_msisdn_email_required(self) -> None:
channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.result["code"], b"401", channel.result)
+ self.assertEqual(channel.code, 401, msg=channel.result)
flows = channel.json_body["flows"]
# with the stock config, we expect all four combinations of 3pid
@@ -803,13 +797,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
@@ -827,15 +821,14 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
- params = {"user_id": user_id}
- request_data = json.dumps(params)
+ request_data = {"user_id": user_id}
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_manual_expire(self) -> None:
user_id = self.register_user("kermit", "monkey")
@@ -845,19 +838,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
- params = {
+ request_data = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
)
@@ -870,25 +862,24 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
- params = {
+ request_data = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
- request_data = json.dumps(params)
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Try to log the user out
channel = self.make_request(b"POST", "/logout", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Log the user in again (allowed for expired accounts)
tok = self.login("kermit", "monkey")
# Try to log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
@@ -963,7 +954,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -981,7 +972,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Move 1 day forward. Try to renew with the same token again.
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -1001,14 +992,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# succeed.
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
def test_renewal_invalid_token(self) -> None:
# Hit the renewal endpoint with an invalid token and check that it behaves as
# expected, i.e. that it responds with 404 Not Found and the correct HTML.
url = "/_matrix/client/unstable/account_validity/renew?token=123"
channel = self.make_request(b"GET", url)
- self.assertEqual(channel.result["code"], b"404", channel.result)
+ self.assertEqual(channel.code, 404, msg=channel.result)
# Check that we're getting HTML back.
content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -1032,7 +1023,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -1041,16 +1032,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
(user_id, tok) = self.create_user()
- request_data = json.dumps(
- {
- "auth": {
- "type": "m.login.password",
- "user": user_id,
- "password": "monkey",
- },
- "erase": False,
- }
- )
+ request_data = {
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "monkey",
+ },
+ "erase": False,
+ }
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
@@ -1107,7 +1096,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(len(self.email_attempts), 1)
@@ -1187,7 +1176,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], True)
def test_GET_token_invalid(self) -> None:
@@ -1196,7 +1185,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["valid"], False)
@override_config(
@@ -1212,10 +1201,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
)
if i == 5:
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, msg=channel.result)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
@@ -1223,4 +1212,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
b"GET",
f"{self.url}?token={token}",
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 62e4db23ef..651f4f415d 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -728,6 +728,7 @@ class RelationsTestCase(BaseRelationsTestCase):
class RelationPaginationTestCase(BaseRelationsTestCase):
+ @unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
def test_basic_paginate_relations(self) -> None:
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -799,7 +800,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
)
expected_event_ids.append(channel.json_body["event_id"])
- prev_token = ""
+ prev_token: Optional[str] = ""
found_event_ids: List[str] = []
for _ in range(20):
from_token = ""
@@ -998,7 +999,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6)
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
def test_annotation_to_annotation(self) -> None:
"""Any relation to an annotation should be ignored."""
@@ -1034,7 +1035,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations,
)
- self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
def test_thread(self) -> None:
"""
@@ -1059,6 +1060,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
participated, bundled_aggregations.get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
+ self.assertIn("latest_event", bundled_aggregations)
self.assert_dict(
{
"content": {
@@ -1071,28 +1073,28 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"sender": self.user2_id,
"type": "m.room.test",
},
- bundled_aggregations.get("latest_event"),
+ bundled_aggregations["latest_event"],
)
return assert_thread
# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
- self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
# Note that this re-uses some cached values, so the total number of
# queries is much smaller.
self._test_bundled_aggregations(
- RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+ RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token
)
# A user with no interactions with the thread: they have not participated.
user3_id, user3_token = self._create_user("charlie")
self.helper.join(self.room, user=user3_id, tok=user3_token)
self._test_bundled_aggregations(
- RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+ RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token
)
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
@@ -1111,6 +1113,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertTrue(bundled_aggregations.get("current_user_participated"))
# The latest thread event has some fields that don't matter.
+ self.assertIn("latest_event", bundled_aggregations)
self.assert_dict(
{
"content": {
@@ -1123,7 +1126,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"sender": self.user_id,
"type": "m.room.test",
},
- bundled_aggregations.get("latest_event"),
+ bundled_aggregations["latest_event"],
)
# Check the unsigned field on the latest event.
self.assert_dict(
@@ -1139,7 +1142,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
def test_nested_thread(self) -> None:
"""
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 20a259fc43..7cb1017a4a 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
-
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -77,11 +75,6 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request(
- "POST",
- self.report_path,
- json.dumps(data),
- access_token=self.other_user_tok,
- )
- self.assertEqual(
- response_status, int(channel.result["code"]), msg=channel.result["body"]
+ "POST", self.report_path, data, access_token=self.other_user_tok
)
+ self.assertEqual(response_status, channel.code, msg=channel.result["body"])
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index ac9c113354..9c8c1889d3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from synapse.visibility import filter_events_for_client
@@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
message_handler.get_room_data(
- self.user_id, room_id, EventTypes.Create, state_key=""
+ create_requester(self.user_id), room_id, EventTypes.Create, state_key=""
)
)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index f523d89b8f..c7eb88d33f 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,10 +18,14 @@
"""Tests REST events for /rooms paths."""
import json
-from typing import Any, Dict, Iterable, List, Optional
+from http import HTTPStatus
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from unittest.mock import Mock, call
from urllib import parse as urlparse
+from parameterized import param, parameterized
+from typing_extensions import Literal
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -30,7 +34,9 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
Membership,
+ PublicRoomsFilterFields,
RelationTypes,
+ RoomTypes,
)
from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
@@ -42,6 +48,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
from tests.test_utils import make_awaitable
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -98,7 +105,7 @@ class RoomPermissionsTestCase(RoomBase):
channel = self.make_request(
"PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# set topic for public room
channel = self.make_request(
@@ -106,7 +113,7 @@ class RoomPermissionsTestCase(RoomBase):
("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
b'{"topic":"Public Room Topic"}',
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# auth as user_id now
self.helper.auth_user_id = self.user_id
@@ -128,28 +135,28 @@ class RoomPermissionsTestCase(RoomBase):
"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content,
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room not joined (no state), expect 403
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# send message in created room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# send message in created room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", send_msg_path(), msg_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_topic_perms(self) -> None:
topic_content = b'{"topic":"My Topic Name"}'
@@ -159,28 +166,28 @@ class RoomPermissionsTestCase(RoomBase):
channel = self.make_request(
"PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room not joined, expect 403
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", topic_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set topic in created PRIVATE room and invited, expect 403
self.helper.invite(
room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
)
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# get topic in created PRIVATE room and invited, expect 403
channel = self.make_request("GET", topic_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set/get topic in created PRIVATE room and joined, expect 200
self.helper.join(room=self.created_rmid, user=self.user_id)
@@ -188,25 +195,25 @@ class RoomPermissionsTestCase(RoomBase):
# Only room ops can set topic by default
self.helper.auth_user_id = self.rmcreator_id
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.helper.auth_user_id = self.user_id
channel = self.make_request("GET", topic_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
# set/get topic in created PRIVATE room and left, expect 403
self.helper.leave(room=self.created_rmid, user=self.user_id)
channel = self.make_request("PUT", topic_path, topic_content)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", topic_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# get topic in PUBLIC room, not joined, expect 403
channel = self.make_request(
"GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
# set topic in PUBLIC room, not joined, expect 403
channel = self.make_request(
@@ -214,7 +221,7 @@ class RoomPermissionsTestCase(RoomBase):
"/rooms/%s/state/m.room.topic" % self.created_public_rmid,
topic_content,
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def _test_get_membership(
self, room: str, members: Iterable = frozenset(), expect_code: int = 200
@@ -303,14 +310,14 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
self.helper.change_membership(
room=room,
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
def test_joined_permissions(self) -> None:
@@ -336,7 +343,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# set left of other, expect 403
@@ -345,7 +352,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# set left of self, expect 200
@@ -365,7 +372,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=usr,
membership=Membership.INVITE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
self.helper.change_membership(
@@ -373,7 +380,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=usr,
membership=Membership.JOIN,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# It is always valid to LEAVE if you've already left (currently.)
@@ -382,7 +389,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
- expect_code=403,
+ expect_code=HTTPStatus.FORBIDDEN,
)
# tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
@@ -399,7 +406,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.BAN,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.FORBIDDEN,
)
@@ -409,7 +416,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.BAN,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
# from ban to invite: Must never happen.
@@ -418,7 +425,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.INVITE,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -428,7 +435,7 @@ class RoomPermissionsTestCase(RoomBase):
src=other,
targ=other,
membership=Membership.JOIN,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -438,7 +445,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.BAN,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
# from ban to knock: Must never happen.
@@ -447,7 +454,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.KNOCK,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.BAD_STATE,
)
@@ -457,7 +464,7 @@ class RoomPermissionsTestCase(RoomBase):
src=self.user_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=403, # expect failure
+ expect_code=HTTPStatus.FORBIDDEN, # expect failure
expect_errcode=Codes.FORBIDDEN,
)
@@ -467,10 +474,53 @@ class RoomPermissionsTestCase(RoomBase):
src=self.rmcreator_id,
targ=other,
membership=Membership.LEAVE,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
+class RoomStateTestCase(RoomBase):
+ """Tests /rooms/$room_id/state."""
+
+ user_id = "@sid1:red"
+
+ def test_get_state_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/state` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_state_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/state" % room_id,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertCountEqual(
+ [state_event["type"] for state_event in channel.json_list],
+ {
+ "m.room.create",
+ "m.room.power_levels",
+ "m.room.join_rules",
+ "m.room.member",
+ "m.room.history_visibility",
+ },
+ )
+
+ def test_get_state_event_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/state/$event_type` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_state_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id),
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertEqual(channel.json_body, {"membership": "join"})
+
+
class RoomsMemberListTestCase(RoomBase):
"""Tests /rooms/$room_id/members/list REST events."""
@@ -481,16 +531,16 @@ class RoomsMemberListTestCase(RoomBase):
def test_get_member_list(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self) -> None:
channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self) -> None:
room_id = self.helper.create_room_as("@some_other_guy:red")
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_with_at_token(self) -> None:
"""
@@ -501,7 +551,7 @@ class RoomsMemberListTestCase(RoomBase):
# first sync to get an at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check that permission is denied for @sid1:red to get the
@@ -510,7 +560,7 @@ class RoomsMemberListTestCase(RoomBase):
"GET",
f"/rooms/{room_id}/members?at={sync_token}",
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member(self) -> None:
"""
@@ -523,14 +573,14 @@ class RoomsMemberListTestCase(RoomBase):
# check that the user can see the member list to start with
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# ban the user
self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban")
# check the user can no longer see the member list
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member_with_at_token(self) -> None:
"""
@@ -544,14 +594,14 @@ class RoomsMemberListTestCase(RoomBase):
# sync to get an at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check that the user can see the member list to start with
channel = self.make_request(
"GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# ban the user (Note: the user is actually allowed to see this event and
# state so that they know they're banned!)
@@ -563,14 +613,14 @@ class RoomsMemberListTestCase(RoomBase):
# now, with the original user, sync again to get a new at token
channel = self.make_request("GET", "/sync")
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
sync_token = channel.json_body["next_batch"]
# check the user can no longer see the updated member list
channel = self.make_request(
"GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self) -> None:
room_creator = "@some_other_guy:red"
@@ -579,17 +629,73 @@ class RoomsMemberListTestCase(RoomBase):
self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
# can't see list if you're just invited.
channel = self.make_request("GET", room_path)
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.helper.join(room=room_id, user=self.user_id)
# can see list now joined
channel = self.make_request("GET", room_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.helper.leave(room=room_id, user=self.user_id)
# can see old list once left
channel = self.make_request("GET", room_path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+
+ def test_get_member_list_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/members` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_get_member_list_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/members" % room_id,
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertLessEqual(
+ {
+ "content": {"membership": "join"},
+ "room_id": room_id,
+ "sender": self.user_id,
+ "state_key": self.user_id,
+ "type": "m.room.member",
+ "user_id": self.user_id,
+ }.items(),
+ channel.json_body["chunk"][0].items(),
+ )
+
+ def test_get_member_list_with_at_token_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/members?at=<sync token>` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # first sync to get an at token
+ channel = self.make_request("GET", "/sync")
+ self.assertEqual(HTTPStatus.OK, channel.code)
+ sync_token = channel.json_body["next_batch"]
+
+ channel = make_request_with_cancellation_test(
+ "test_get_member_list_with_at_token_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/members?at=%s" % (room_id, sync_token),
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertLessEqual(
+ {
+ "content": {"membership": "join"},
+ "room_id": room_id,
+ "sender": self.user_id,
+ "state_key": self.user_id,
+ "type": "m.room.member",
+ "user_id": self.user_id,
+ }.items(),
+ channel.json_body["chunk"][0].items(),
+ )
class RoomsCreateTestCase(RoomBase):
@@ -601,19 +707,34 @@ class RoomsCreateTestCase(RoomBase):
# POST with no config keys, expect new room id
channel = self.make_request("POST", "/createRoom", "{}")
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(44, channel.resource_usage.db_txn_count)
+
+ def test_post_room_initial_state(self) -> None:
+ # POST with initial_state config key, expect new room id
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ b'{"initial_state":[{"type": "m.bridge", "content": {}}]}',
+ )
+
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ self.assertTrue("room_id" in channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(50, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self) -> None:
# POST with custom config keys, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self) -> None:
@@ -621,16 +742,16 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_invalid_content(self) -> None:
# POST with invalid content / paths, expect 400
channel = self.make_request("POST", "/createRoom", b'{"visibili')
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
channel = self.make_request("POST", "/createRoom", b'["hello"]')
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
def test_post_room_invitees_invalid_mxid(self) -> None:
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
@@ -638,7 +759,7 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request(
"POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
)
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
@unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
def test_post_room_invitees_ratelimit(self) -> None:
@@ -649,20 +770,18 @@ class RoomsCreateTestCase(RoomBase):
# Build the request's content. We use local MXIDs because invites over federation
# are more difficult to mock.
- content = json.dumps(
- {
- "invite": [
- "@alice1:red",
- "@alice2:red",
- "@alice3:red",
- "@alice4:red",
- ]
- }
- ).encode("utf8")
+ content = {
+ "invite": [
+ "@alice1:red",
+ "@alice2:red",
+ "@alice3:red",
+ "@alice4:red",
+ ]
+ }
# Test that the invites are correctly ratelimited.
channel = self.make_request("POST", "/createRoom", content)
- self.assertEqual(400, channel.code)
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code)
self.assertEqual(
"Cannot invite so many users at once",
channel.json_body["error"],
@@ -675,11 +794,13 @@ class RoomsCreateTestCase(RoomBase):
# Test that the invites aren't ratelimited anymore.
channel = self.make_request("POST", "/createRoom", content)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
- def test_spam_checker_may_join_room(self) -> None:
+ def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
when creating a new room.
+
+ In this test, we use the deprecated API in which callbacks return a bool.
"""
async def user_may_join_room(
@@ -697,10 +818,55 @@ class RoomsCreateTestCase(RoomBase):
"/createRoom",
{},
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(join_mock.call_count, 0)
+ def test_spam_checker_may_join_room(self) -> None:
+ """Tests that the user_may_join_room spam checker callback is correctly bypassed
+ when creating a new room.
+
+ In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
+ """
+
+ async def user_may_join_room_codes(
+ mxid: str,
+ room_id: str,
+ is_invite: bool,
+ ) -> Codes:
+ return Codes.CONSENT_NOT_GIVEN
+
+ join_mock = Mock(side_effect=user_may_join_room_codes)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
+
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+ self.assertEqual(join_mock.call_count, 0)
+
+ # Now change the return value of the callback to deny any join. Since we're
+ # creating the room, despite the return value, we should be able to join.
+ async def user_may_join_room_tuple(
+ mxid: str,
+ room_id: str,
+ is_invite: bool,
+ ) -> Tuple[Codes, dict]:
+ return Codes.INCOMPATIBLE_ROOM_VERSION, {}
+
+ join_mock.side_effect = user_may_join_room_tuple
+
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ {},
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+ self.assertEqual(join_mock.call_count, 0)
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@@ -715,54 +881,68 @@ class RoomTopicTestCase(RoomBase):
def test_invalid_puts(self) -> None:
# missing keys or invalid json
channel = self.make_request("PUT", self.path, "{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, '{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request(
"PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, "text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", self.path, "")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# valid key, wrong type
content = '{"topic":["Topic name"]}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_topic(self) -> None:
# nothing should be there
channel = self.make_request("GET", self.path)
- self.assertEqual(404, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.result["body"])
# valid put
content = '{"topic":"Topic name"}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# valid get
channel = self.make_request("GET", self.path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self) -> None:
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
channel = self.make_request("PUT", self.path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# valid get
channel = self.make_request("GET", self.path)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
@@ -778,22 +958,34 @@ class RoomMemberStateTestCase(RoomBase):
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
channel = self.make_request("PUT", path, "{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, '{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, '{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, "text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, "")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# valid keys, wrong types
content = '{"membership":["%s","%s","%s"]}' % (
@@ -802,7 +994,9 @@ class RoomMemberStateTestCase(RoomBase):
Membership.LEAVE,
)
channel = self.make_request("PUT", path, content.encode("ascii"))
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_members_self(self) -> None:
path = "/rooms/%s/state/m.room.member/%s" % (
@@ -813,10 +1007,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid join message (NOOP since we made the room)
content = '{"membership":"%s"}' % Membership.JOIN
channel = self.make_request("PUT", path, content.encode("ascii"))
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
self.assertEqual(expected_response, channel.json_body)
@@ -831,10 +1025,10 @@ class RoomMemberStateTestCase(RoomBase):
# valid invite message
content = '{"membership":"%s"}' % Membership.INVITE
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
def test_rooms_members_other_custom_keys(self) -> None:
@@ -850,10 +1044,10 @@ class RoomMemberStateTestCase(RoomBase):
"Join us!",
)
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, content=b"")
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
@@ -911,9 +1105,11 @@ class RoomJoinTestCase(RoomBase):
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
- def test_spam_checker_may_join_room(self) -> None:
+ def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed.
+
+ This test uses the deprecated API, in which callbacks return booleans.
"""
# Register a dummy callback. Make it allow all room joins for now.
@@ -926,6 +1122,8 @@ class RoomJoinTestCase(RoomBase):
) -> bool:
return return_value
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
@@ -966,7 +1164,92 @@ class RoomJoinTestCase(RoomBase):
# Now make the callback deny all room joins, and check that a join actually fails.
return_value = False
- self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
+ self.helper.join(
+ self.room3, self.user2, expect_code=HTTPStatus.FORBIDDEN, tok=self.tok2
+ )
+
+ def test_spam_checker_may_join_room(self) -> None:
+ """Tests that the user_may_join_room spam checker callback is correctly called
+ and blocks room joins when needed.
+
+ This test uses the latest API to this day, in which callbacks return `NOT_SPAM` or `Codes`.
+ """
+
+ # Register a dummy callback. Make it allow all room joins for now.
+ return_value: Union[
+ Literal["NOT_SPAM"], Tuple[Codes, dict], Codes
+ ] = synapse.module_api.NOT_SPAM
+
+ async def user_may_join_room(
+ userid: str,
+ room_id: str,
+ is_invited: bool,
+ ) -> Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes]:
+ return return_value
+
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
+ callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
+ self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
+
+ # Join a first room, without being invited to it.
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room1,
+ False,
+ ),
+ )
+ self.assertEqual(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Join a second room, this time with an invite for it.
+ self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
+ self.helper.join(self.room2, self.user2, tok=self.tok2)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = (
+ (
+ self.user2,
+ self.room2,
+ True,
+ ),
+ )
+ self.assertEqual(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now make the callback deny all room joins, and check that a join actually fails.
+ # We pick an arbitrary Codes rather than the default `Codes.FORBIDDEN`.
+ return_value = Codes.CONSENT_NOT_GIVEN
+ self.helper.invite(self.room3, self.user1, self.user2, tok=self.tok1)
+ self.helper.join(
+ self.room3,
+ self.user2,
+ expect_code=HTTPStatus.FORBIDDEN,
+ expect_errcode=return_value,
+ tok=self.tok2,
+ )
+
+ # Now make the callback deny all room joins, and check that a join actually fails.
+ # As above, with the experimental extension that lets us return dictionaries.
+ return_value = (Codes.BAD_ALIAS, {"another_field": "12345"})
+ self.helper.join(
+ self.room3,
+ self.user2,
+ expect_code=HTTPStatus.FORBIDDEN,
+ expect_errcode=return_value[0],
+ tok=self.tok2,
+ expect_additional_fields=return_value[1],
+ )
class RoomJoinRatelimitTestCase(RoomBase):
@@ -1016,7 +1299,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
channel = self.make_request("PUT", path, {"displayname": "John Doe"})
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that all the rooms have been sent a profile update into.
for room_id in room_ids:
@@ -1081,40 +1364,153 @@ class RoomMessagesTestCase(RoomBase):
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
channel = self.make_request("PUT", path, b"{}")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'{"_name":"bo"}')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'{"nao')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b"text only")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
channel = self.make_request("PUT", path, b"")
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
def test_rooms_messages_sent(self) -> None:
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = b'{"body":"test","msgtype":{"type":"a"}}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
# custom message types
content = b'{"body":"test","msgtype":"test.custom.text"}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# m.text message type
path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
content = b'{"body":"test2","msgtype":"m.text"}'
channel = self.make_request("PUT", path, content)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
+
+ @parameterized.expand(
+ [
+ # Allow
+ param(
+ name="NOT_SPAM",
+ value="NOT_SPAM",
+ expected_code=HTTPStatus.OK,
+ expected_fields={},
+ ),
+ param(
+ name="False",
+ value=False,
+ expected_code=HTTPStatus.OK,
+ expected_fields={},
+ ),
+ # Block
+ param(
+ name="scalene string",
+ value="ANY OTHER STRING",
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={"errcode": "M_FORBIDDEN"},
+ ),
+ param(
+ name="True",
+ value=True,
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={"errcode": "M_FORBIDDEN"},
+ ),
+ param(
+ name="Code",
+ value=Codes.LIMIT_EXCEEDED,
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={"errcode": "M_LIMIT_EXCEEDED"},
+ ),
+ param(
+ name="Tuple",
+ value=(Codes.SERVER_NOT_TRUSTED, {"additional_field": "12345"}),
+ expected_code=HTTPStatus.FORBIDDEN,
+ expected_fields={
+ "errcode": "M_SERVER_NOT_TRUSTED",
+ "additional_field": "12345",
+ },
+ ),
+ ]
+ )
+ def test_spam_checker_check_event_for_spam(
+ self,
+ name: str,
+ value: Union[str, bool, Codes, Tuple[Codes, JsonDict]],
+ expected_code: int,
+ expected_fields: dict,
+ ) -> None:
+ class SpamCheck:
+ mock_return_value: Union[
+ str, bool, Codes, Tuple[Codes, JsonDict], bool
+ ] = "NOT_SPAM"
+ mock_content: Optional[JsonDict] = None
+
+ async def check_event_for_spam(
+ self,
+ event: synapse.events.EventBase,
+ ) -> Union[str, Codes, Tuple[Codes, JsonDict], bool]:
+ self.mock_content = event.content
+ return self.mock_return_value
+
+ spam_checker = SpamCheck()
+
+ self.hs.get_spam_checker()._check_event_for_spam_callbacks.append(
+ spam_checker.check_event_for_spam
+ )
+
+ # Inject `value` as mock_return_value
+ spam_checker.mock_return_value = value
+ path = "/rooms/%s/send/m.room.message/check_event_for_spam_%s" % (
+ urlparse.quote(self.room_id),
+ urlparse.quote(name),
+ )
+ body = "test-%s" % name
+ content = '{"body":"%s","msgtype":"m.text"}' % body
+ channel = self.make_request("PUT", path, content)
+
+ # Check that the callback has witnessed the correct event.
+ self.assertIsNotNone(spam_checker.mock_content)
+ if (
+ spam_checker.mock_content is not None
+ ): # Checked just above, but mypy doesn't know about that.
+ self.assertEqual(
+ spam_checker.mock_content["body"], body, spam_checker.mock_content
+ )
+
+ # Check that we have the correct result.
+ self.assertEqual(expected_code, channel.code, msg=channel.result["body"])
+ for expected_key, expected_value in expected_fields.items():
+ self.assertEqual(
+ channel.json_body.get(expected_key, None),
+ expected_value,
+ "Field %s absent or invalid " % expected_key,
+ )
class RoomPowerLevelOverridesTestCase(RoomBase):
@@ -1239,7 +1635,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am allowed
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
def test_normal_user_can_not_post_state_event(self) -> None:
# Given I am a normal member of a room
@@ -1253,7 +1649,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed because state events require PL>=50
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
"user_level (0) < send_level (50)",
@@ -1280,7 +1676,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am allowed
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
@unittest.override_config(
{
@@ -1308,7 +1704,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
@unittest.override_config(
{
@@ -1336,7 +1732,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
channel = self.make_request("PUT", path, "{}")
# Then I am not allowed
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
+ "user_level (0) < send_level (1)",
@@ -1367,7 +1763,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase):
# Then I am not allowed because the public_chat config does not
# affect this room, because this room is a private_chat
- self.assertEqual(403, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"])
self.assertEqual(
"You don't have permission to post that to the room. "
+ "user_level (0) < send_level (50)",
@@ -1386,7 +1782,7 @@ class RoomInitialSyncTestCase(RoomBase):
def test_initial_sync(self) -> None:
channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertEqual(self.room_id, channel.json_body["room_id"])
self.assertEqual("join", channel.json_body["membership"])
@@ -1429,7 +1825,7 @@ class RoomMessageListTestCase(RoomBase):
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEqual(token, channel.json_body["start"])
self.assertTrue("chunk" in channel.json_body)
@@ -1440,7 +1836,7 @@ class RoomMessageListTestCase(RoomBase):
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
- self.assertEqual(200, channel.code)
+ self.assertEqual(HTTPStatus.OK, channel.code)
self.assertTrue("start" in channel.json_body)
self.assertEqual(token, channel.json_body["start"])
self.assertTrue("chunk" in channel.json_body)
@@ -1479,7 +1875,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
@@ -1507,7 +1903,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 1, [event["content"] for event in chunk])
@@ -1524,7 +1920,7 @@ class RoomMessageListTestCase(RoomBase):
json.dumps({"types": [EventTypes.Message]}),
),
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
chunk = channel.json_body["chunk"]
self.assertEqual(len(chunk), 0, [event["content"] for event in chunk])
@@ -1652,14 +2048,97 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
def test_restricted_no_auth(self) -> None:
channel = self.make_request("GET", self.url)
- self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
def test_restricted_auth(self) -> None:
self.register_user("user", "pass")
tok = self.login("user", "pass")
channel = self.make_request("GET", self.url, access_token=tok)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+
+class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+
+ config = self.default_config()
+ config["allow_public_rooms_without_auth"] = True
+ self.hs = self.setup_test_homeserver(config=config)
+ self.url = b"/_matrix/client/r0/publicRooms"
+
+ return self.hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ user = self.register_user("alice", "pass")
+ self.token = self.login(user, "pass")
+
+ # Create a room
+ self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=self.token,
+ )
+ # Create a space
+ self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={
+ "visibility": "public",
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+ },
+ tok=self.token,
+ )
+
+ def make_public_rooms_request(
+ self, room_types: Union[List[Union[str, None]], None]
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ channel = self.make_request(
+ "POST",
+ self.url,
+ {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
+ self.token,
+ )
+ chunk = channel.json_body["chunk"]
+ count = channel.json_body["total_room_count_estimate"]
+
+ self.assertEqual(len(chunk), count)
+
+ return chunk, count
+
+ def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(None)
+
+ self.assertEqual(count, 2)
+
+ def test_returns_only_rooms_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request([None])
+
+ self.assertEqual(count, 1)
+ self.assertEqual(chunk[0].get("room_type", None), None)
+
+ def test_returns_only_space_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(["m.space"])
+
+ self.assertEqual(count, 1)
+ self.assertEqual(chunk[0].get("room_type", None), "m.space")
+
+ def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(["m.space", None])
+
+ self.assertEqual(count, 2)
+
+ def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
+ chunk, count = self.make_public_rooms_request([])
+
+ self.assertEqual(count, 2)
class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
@@ -1686,7 +2165,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
"Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
- search_filter = {"generic_search_term": "foobar"}
+ search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
channel = self.make_request(
"POST",
@@ -1694,7 +2173,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
content={"filter": search_filter},
access_token=self.token,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined]
"testserv",
@@ -1711,11 +2190,11 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
# The `get_public_rooms` should be called again if the first call fails
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
- HttpResponseException(404, "Not Found", b""),
+ HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""),
make_awaitable({}),
)
- search_filter = {"generic_search_term": "foobar"}
+ search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
channel = self.make_request(
"POST",
@@ -1723,7 +2202,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
content={"filter": search_filter},
access_token=self.token,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined]
[
@@ -1769,21 +2248,19 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
# Set a profile for the test user
self.displayname = "test user"
- data = {"displayname": self.displayname}
- request_data = json.dumps(data)
+ request_data = {"displayname": self.displayname}
channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
request_data,
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_per_room_profile_forbidden(self) -> None:
- data = {"membership": "join", "displayname": "other test user"}
- request_data = json.dumps(data)
+ request_data = {"membership": "join", "displayname": "other test user"}
channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
@@ -1791,7 +2268,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
request_data,
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -1799,7 +2276,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
res_displayname = channel.json_body["content"]["displayname"]
self.assertEqual(res_displayname, self.displayname, channel.result)
@@ -1833,7 +2310,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1847,7 +2324,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1861,7 +2338,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1875,7 +2352,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1887,7 +2364,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1899,7 +2376,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1918,7 +2395,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
content={"reason": reason},
access_token=self.second_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self._check_for_reason(reason)
@@ -1930,7 +2407,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
),
access_token=self.creator_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
event_content = channel.json_body
@@ -1978,7 +2455,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2008,7 +2485,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2043,7 +2520,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2123,16 +2600,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
def test_search_filter_labels(self) -> None:
"""Test that we can filter by a label on a /search request."""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2160,16 +2635,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
def test_search_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /search request."""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_NOT_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_NOT_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2209,16 +2682,14 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"""Test that we can filter by both a label and the absence of another label on a
/search request.
"""
- request_data = json.dumps(
- {
- "search_categories": {
- "room_events": {
- "search_term": "label",
- "filter": self.FILTER_LABELS_NOT_LABELS,
- }
+ request_data = {
+ "search_categories": {
+ "room_events": {
+ "search_term": "label",
+ "filter": self.FILTER_LABELS_NOT_LABELS,
}
}
- )
+ }
self._send_labelled_messages_in_room()
@@ -2391,7 +2862,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["chunk"]
@@ -2496,7 +2967,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2562,7 +3033,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
% (self.room_id, event_id),
access_token=invited_tok,
)
- self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
events_before = channel.json_body["events_before"]
@@ -2663,8 +3134,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
@@ -2693,8 +3163,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
- data = {"room_id": self.room_id}
- request_data = json.dumps(data)
+ request_data = {"room_id": self.room_id}
channel = self.make_request(
"PUT", url, request_data, access_token=self.room_owner_tok
@@ -2720,7 +3189,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
"rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
- json.dumps(content),
+ content,
access_token=self.room_owner_tok,
)
self.assertEqual(channel.code, expected_code, channel.result)
@@ -2845,11 +3314,16 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
- def test_threepid_invite_spamcheck(self) -> None:
+ def test_threepid_invite_spamcheck_deprecated(self) -> None:
+ """
+ Test allowing/blocking threepid invites with a spam-check module.
+
+ In this test, we use the deprecated API in which callbacks return a bool.
+ """
# Mock a few functions to prevent the test from failing due to failing to talk to
- # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
+ # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
- make_invite_mock = Mock(return_value=make_awaitable(0))
+ make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
self.hs.get_identity_handler().lookup_3pid = Mock(
return_value=make_awaitable(None),
@@ -2901,3 +3375,107 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Also check that it stopped before calling _make_and_store_3pid_invite.
make_invite_mock.assert_called_once()
+
+ def test_threepid_invite_spamcheck(self) -> None:
+ """
+ Test allowing/blocking threepid invites with a spam-check module.
+
+ In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
+ # Mock a few functions to prevent the test from failing due to failing to talk to
+ # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
+ # can check its call_count later on during the test.
+ make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
+ self.hs.get_identity_handler().lookup_3pid = Mock(
+ return_value=make_awaitable(None),
+ )
+
+ # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
+ # allow everything for now.
+ # `spec` argument is needed for this function mock to have `__qualname__`, which
+ # is needed for `Measure` metrics buried in SpamChecker.
+ mock = Mock(
+ return_value=make_awaitable(synapse.module_api.NOT_SPAM),
+ spec=lambda *x: None,
+ )
+ self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
+
+ # Send a 3PID invite into the room and check that it succeeded.
+ email_to_invite = "teresa@example.com"
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Check that the callback was called with the right params.
+ mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
+
+ # Check that the call to send the invite was made.
+ make_invite_mock.assert_called_once()
+
+ # Now change the return value of the callback to deny any invite and test that
+ # we can't send the invite. We pick an arbitrary error code to be able to check
+ # that the same code has been returned
+ mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], Codes.CONSENT_NOT_GIVEN)
+
+ # Also check that it stopped before calling _make_and_store_3pid_invite.
+ make_invite_mock.assert_called_once()
+
+ # Run variant with `Tuple[Codes, dict]`.
+ mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"}))
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": email_to_invite,
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT)
+ self.assertEqual(channel.json_body["field"], "value")
+
+ # Also check that it stopped before calling _make_and_store_3pid_invite.
+ make_invite_mock.assert_called_once()
+
+ def test_400_missing_param_without_id_access_token(self) -> None:
+ """
+ Test that a 3pid invite request returns 400 M_MISSING_PARAM
+ if we do not include id_access_token.
+ """
+ channel = self.make_request(
+ method="POST",
+ path="/rooms/" + self.room_id + "/invite",
+ content={
+ "id_server": "example.com",
+ "medium": "email",
+ "address": "teresa@example.com",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM")
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d9bd8c4a28..c807a37bc2 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -26,7 +26,7 @@ from synapse.rest.client import (
room_upgrade_rest_servlet,
)
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -97,7 +97,12 @@ class RoomTestCase(_ShadowBannedBase):
channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
- {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ {
+ "id_server": "test",
+ "medium": "email",
+ "address": "test@test.test",
+ "id_access_token": "anytoken",
+ },
access_token=self.banned_access_token,
)
self.assertEqual(200, channel.code, channel.result)
@@ -275,7 +280,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
@@ -310,7 +315,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
- self.banned_user_id,
+ create_requester(self.banned_user_id),
room_id,
"m.room.member",
self.banned_user_id,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index e3efd1f1b0..0af643ecd9 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -38,7 +38,6 @@ from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin,
)
from tests.server import TimedOutException
-from tests.unittest import override_config
class FilterTestCase(unittest.HomeserverTestCase):
@@ -390,6 +389,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+
+ return self.setup_test_homeserver(config=config)
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s"
self.next_batch = "s0"
@@ -408,7 +412,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Join the second user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_private_read_receipts(self) -> None:
# Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@@ -416,7 +419,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Send a private read receipt to tell the server the first user's message was read
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok2,
)
@@ -425,7 +428,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that the first user can't see the other user's private read receipt
self.assertIsNone(self._get_read_receipt())
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_public_receipt_can_override_private(self) -> None:
"""
Sending a public read receipt to the same event which has a private read
@@ -456,7 +458,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
# Test that we did override the private read receipt
self.assertNotEqual(self._get_read_receipt(), None)
- @override_config({"experimental_features": {"msc2285_enabled": True}})
def test_private_receipt_cannot_override_public(self) -> None:
"""
Sending a private read receipt to the same event which has a public read
@@ -543,7 +544,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
config = super().default_config()
config["experimental_features"] = {
"msc2654_enabled": True,
- "msc2285_enabled": True,
}
return config
@@ -606,11 +606,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(1)
# Send a read receipt to tell the server we've read the latest event.
- body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8")
channel = self.make_request(
"POST",
f"/rooms/{self.room_id}/read_markers",
- body,
+ {ReceiptTypes.READ: res["event_id"]},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -625,7 +624,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Send a read receipt to tell the server we've read the latest event.
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
{},
access_token=self.tok,
)
@@ -701,7 +700,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(5)
res2 = self.helper.send(self.room_id, "hello", tok=self.tok2)
- # Make sure both m.read and org.matrix.msc2285.read.private advance
+ # Make sure both m.read and m.read.private advance
channel = self.make_request(
"POST",
f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}",
@@ -713,16 +712,21 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}",
{},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0)
- # We test for both receipt types that influence notification counts
- @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE])
- def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None:
+ # We test for all three receipt types that influence notification counts
+ @parameterized.expand(
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ]
+ )
+ def test_read_receipts_only_go_down(self, receipt_type: str) -> None:
# Join the new user
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
@@ -733,18 +737,18 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
# Read last event
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}",
{},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
self._check_unread_count(0)
- # Make sure neither m.read nor org.matrix.msc2285.read.private make the
+ # Make sure neither m.read nor m.read.private make the
# read receipt go up to an older event
channel = self.make_request(
"POST",
- f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}",
+ f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res1['event_id']}",
{},
access_token=self.tok,
)
@@ -949,3 +953,24 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"])
self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"])
+
+ def test_incremental_sync(self) -> None:
+ """Tests that activity in the room is properly filtered out of incremental
+ syncs.
+ """
+ channel = self.make_request("GET", "/sync", access_token=self.tok)
+ self.assertEqual(channel.code, 200, channel.result)
+ next_batch = channel.json_body["next_batch"]
+
+ self.helper.send(self.excluded_room_id, tok=self.tok)
+ self.helper.send(self.included_room_id, tok=self.tok)
+
+ channel = self.make_request(
+ "GET",
+ f"/sync?since={next_batch}",
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
+ self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 5eb0f243f7..3325d43a2f 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -20,8 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, LoginType, Membership
from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion
+from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.rest import admin
from synapse.rest.client import account, login, profile, room
@@ -113,14 +113,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver.
- async def _check_event_auth(
- origin: str,
- event: EventBase,
- context: EventContext,
- *args: Any,
- **kwargs: Any,
- ) -> EventContext:
- return context
+ async def _check_event_auth(origin: Any, event: Any, context: Any) -> None:
+ pass
hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
@@ -161,7 +155,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
callback.assert_called_once()
@@ -179,7 +173,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.code, 403, channel.result)
def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
"""
@@ -192,12 +186,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
class NastyHackException(SynapseError):
- def error_dict(self) -> JsonDict:
+ def error_dict(self, config: Optional[HomeServerConfig]) -> JsonDict:
"""
This overrides SynapseError's `error_dict` to nastily inject
JSON into the error response.
"""
- result = super().error_dict()
+ result = super().error_dict(config)
result["nasty"] = "very"
return result
@@ -217,7 +211,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
access_token=self.tok,
)
# Check the error code
- self.assertEqual(channel.result["code"], b"429", channel.result)
+ self.assertEqual(channel.code, 429, channel.result)
# Check the JSON body has had the `nasty` key injected
self.assertEqual(
channel.json_body,
@@ -266,7 +260,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{"x": "x"},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
# ... and check that it got modified
@@ -275,7 +269,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
@@ -304,7 +298,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
orig_event_id = channel.json_body["event_id"]
channel = self.make_request(
@@ -321,7 +315,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
edited_event_id = channel.json_body["event_id"]
# ... and check that they both got modified
@@ -330,7 +324,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
@@ -339,7 +333,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "EDITED BODY")
@@ -385,7 +379,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
},
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
@@ -394,7 +388,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
- self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.code, 200, channel.result)
self.assertIn("foo", channel.json_body["content"].keys())
self.assertEqual(channel.json_body["content"]["foo"], "bar")
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 98c1039d33..5e7bf97482 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -48,10 +48,14 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.helper.join(self.room_id, self.other, tok=self.other_token)
def _upgrade_room(
- self, token: Optional[str] = None, room_id: Optional[str] = None
+ self,
+ token: Optional[str] = None,
+ room_id: Optional[str] = None,
+ expire_cache: bool = True,
) -> FakeChannel:
- # We never want a cached response.
- self.reactor.advance(5 * 60 + 1)
+ if expire_cache:
+ # We don't want a cached response.
+ self.reactor.advance(5 * 60 + 1)
if room_id is None:
room_id = self.room_id
@@ -72,9 +76,24 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, channel.result)
self.assertIn("replacement_room", channel.json_body)
- def test_not_in_room(self) -> None:
+ new_room_id = channel.json_body["replacement_room"]
+
+ # Check that the tombstone event points to the new room.
+ tombstone_event = self.get_success(
+ self.hs.get_storage_controllers().state.get_current_state_event(
+ self.room_id, EventTypes.Tombstone, ""
+ )
+ )
+ self.assertIsNotNone(tombstone_event)
+ self.assertEqual(new_room_id, tombstone_event.content["replacement_room"])
+
+ # Check that the new room exists.
+ room = self.get_success(self.store.get_room(new_room_id))
+ self.assertIsNotNone(room)
+
+ def test_never_in_room(self) -> None:
"""
- Upgrading a room should work fine.
+ A user who has never been in the room cannot upgrade the room.
"""
# The user isn't in the room.
roomless = self.register_user("roomless", "pass")
@@ -83,6 +102,16 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
channel = self._upgrade_room(roomless_token)
self.assertEqual(403, channel.code, channel.result)
+ def test_left_room(self) -> None:
+ """
+ A user who is no longer in the room cannot upgrade the room.
+ """
+ # Remove the user from the room.
+ self.helper.leave(self.room_id, self.creator, tok=self.creator_token)
+
+ channel = self._upgrade_room(self.creator_token)
+ self.assertEqual(403, channel.code, channel.result)
+
def test_power_levels(self) -> None:
"""
Another user can upgrade the room if their power level is increased.
@@ -297,3 +326,47 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.assertEqual(
create_event.content.get(EventContentFields.ROOM_TYPE), test_room_type
)
+
+ def test_second_upgrade_from_same_user(self) -> None:
+ """A second room upgrade from the same user is deduplicated."""
+ channel1 = self._upgrade_room()
+ self.assertEqual(200, channel1.code, channel1.result)
+
+ channel2 = self._upgrade_room(expire_cache=False)
+ self.assertEqual(200, channel2.code, channel2.result)
+
+ self.assertEqual(
+ channel1.json_body["replacement_room"],
+ channel2.json_body["replacement_room"],
+ )
+
+ def test_second_upgrade_after_delay(self) -> None:
+ """A second room upgrade is not deduplicated after some time has passed."""
+ channel1 = self._upgrade_room()
+ self.assertEqual(200, channel1.code, channel1.result)
+
+ channel2 = self._upgrade_room(expire_cache=True)
+ self.assertEqual(200, channel2.code, channel2.result)
+
+ self.assertNotEqual(
+ channel1.json_body["replacement_room"],
+ channel2.json_body["replacement_room"],
+ )
+
+ def test_second_upgrade_from_different_user(self) -> None:
+ """A second room upgrade from a different user is blocked."""
+ channel = self._upgrade_room()
+ self.assertEqual(200, channel.code, channel.result)
+
+ channel = self._upgrade_room(self.other_token, expire_cache=False)
+ self.assertEqual(400, channel.code, channel.result)
+
+ def test_first_upgrade_does_not_block_second(self) -> None:
+ """A second room upgrade is not blocked when a previous upgrade attempt was not
+ allowed.
+ """
+ channel = self._upgrade_room(self.other_token)
+ self.assertEqual(403, channel.code, channel.result)
+
+ channel = self._upgrade_room(expire_cache=False)
+ self.assertEqual(200, channel.code, channel.result)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index a0788b1bb0..dd26145bf8 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -41,6 +41,7 @@ from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
+from synapse.api.errors import Codes
from synapse.server import HomeServer
from synapse.types import JsonDict
@@ -135,11 +136,11 @@ class RestHelper:
self.site,
"POST",
path,
- json.dumps(content).encode("utf8"),
+ content,
custom_headers=custom_headers,
)
- assert channel.result["code"] == b"%d" % expect_code, channel.result
+ assert channel.code == expect_code, channel.result
self.auth_user_id = temp_id
if expect_code == HTTPStatus.OK:
@@ -171,6 +172,8 @@ class RestHelper:
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
appservice_user_id: Optional[str] = None,
+ expect_errcode: Optional[Codes] = None,
+ expect_additional_fields: Optional[dict] = None,
) -> None:
self.change_membership(
room=room,
@@ -180,6 +183,8 @@ class RestHelper:
appservice_user_id=appservice_user_id,
membership=Membership.JOIN,
expect_code=expect_code,
+ expect_errcode=expect_errcode,
+ expect_additional_fields=expect_additional_fields,
)
def knock(
@@ -205,14 +210,12 @@ class RestHelper:
self.site,
"POST",
path,
- json.dumps(data).encode("utf8"),
+ data,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -263,6 +266,7 @@ class RestHelper:
appservice_user_id: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
expect_errcode: Optional[str] = None,
+ expect_additional_fields: Optional[dict] = None,
) -> None:
"""
Send a membership state event into a room.
@@ -303,14 +307,12 @@ class RestHelper:
self.site,
"PUT",
path,
- json.dumps(data).encode("utf8"),
+ data,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -323,6 +325,21 @@ class RestHelper:
channel.result["body"],
)
+ if expect_additional_fields is not None:
+ for expect_key, expect_value in expect_additional_fields.items():
+ assert expect_key in channel.json_body, "Expected field %s, got %s" % (
+ expect_key,
+ channel.json_body,
+ )
+ assert (
+ channel.json_body[expect_key] == expect_value
+ ), "Expected: %s at %s, got: %s, resp: %s" % (
+ expect_value,
+ expect_key,
+ channel.json_body[expect_key],
+ channel.json_body,
+ )
+
self.auth_user_id = temp_id
def send(
@@ -371,15 +388,13 @@ class RestHelper:
self.site,
"PUT",
path,
- json.dumps(content or {}).encode("utf8"),
+ content or {},
custom_headers=custom_headers,
)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -428,11 +443,9 @@ class RestHelper:
channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
- assert (
- int(channel.result["code"]) == expect_code
- ), "Expected: %d, got: %d, resp: %r" % (
+ assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
@@ -524,7 +537,7 @@ class RestHelper:
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
- int(channel.result["code"]),
+ channel.code,
channel.result["body"],
)
|