summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_account.py109
-rw-r--r--tests/rest/client/test_directory.py59
-rw-r--r--tests/rest/client/test_filter.py14
-rw-r--r--tests/rest/client/test_identity.py7
-rw-r--r--tests/rest/client/test_login.py119
-rw-r--r--tests/rest/client/test_models.py53
-rw-r--r--tests/rest/client/test_password_policy.py31
-rw-r--r--tests/rest/client/test_profile.py18
-rw-r--r--tests/rest/client/test_redactions.py4
-rw-r--r--tests/rest/client/test_register.py199
-rw-r--r--tests/rest/client/test_relations.py21
-rw-r--r--tests/rest/client/test_report_event.py11
-rw-r--r--tests/rest/client/test_retention.py4
-rw-r--r--tests/rest/client/test_rooms.py936
-rw-r--r--tests/rest/client/test_shadow_banned.py13
-rw-r--r--tests/rest/client/test_sync.py59
-rw-r--r--tests/rest/client/test_third_party_rules.py38
-rw-r--r--tests/rest/client/test_upgrade_room.py83
-rw-r--r--tests/rest/client/utils.py57
19 files changed, 1308 insertions, 527 deletions
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py

index a43a137273..c1a7fb2f8a 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py
@@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os import re from email.parser import Parser +from http import HTTPStatus from typing import Any, Dict, List, Optional, Union from unittest.mock import Mock @@ -95,10 +95,8 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): """ body = {"type": "m.login.password", "user": username, "password": password} - channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") - ) - self.assertEqual(channel.code, 403, channel.result) + channel = self.make_request("POST", "/_matrix/client/r0/login", body) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) def test_basic_password_reset(self) -> None: """Test basic password reset flow""" @@ -347,7 +345,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button @@ -362,7 +360,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, content_is_form=True, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" @@ -390,7 +388,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): new_password: str, session_id: str, client_secret: str, - expected_code: int = 200, + expected_code: int = HTTPStatus.OK, ) -> None: channel = self.make_request( "POST", @@ -479,20 +477,18 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertEqual(memberships[0].room_id, room_id, memberships) def deactivate(self, user_id: str, tok: str) -> None: - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "test", - }, - "erase": False, - } - ) + request_data = { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + } channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, 200, channel.json_body) class WhoamiTestCase(unittest.HomeserverTestCase): @@ -645,21 +641,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_no_at(self) -> None: self._request_token_invalid_email( "address-without-at.bar", - expected_errcode=Codes.UNKNOWN, + expected_errcode=Codes.BAD_JSON, expected_error="Unable to parse email address", ) def test_add_email_two_at(self) -> None: self._request_token_invalid_email( "foo@foo@test.bar", - expected_errcode=Codes.UNKNOWN, + expected_errcode=Codes.BAD_JSON, expected_error="Unable to parse email address", ) def test_add_email_bad_format(self) -> None: self._request_token_invalid_email( "user@bad.example.net@good.example.com", - expected_errcode=Codes.UNKNOWN, + expected_errcode=Codes.BAD_JSON, expected_error="Unable to parse email address", ) @@ -715,7 +711,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -725,7 +723,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email(self) -> None: @@ -747,7 +745,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -756,7 +754,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email_if_disabled(self) -> None: @@ -781,7 +779,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -791,7 +791,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) @@ -817,7 +817,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -827,7 +829,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_no_valid_token(self) -> None: @@ -852,7 +854,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -862,7 +866,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @override_config({"next_link_domain_whitelist": None}) @@ -872,7 +876,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="https://example.com/a/good/site", - expect_code=200, + expect_code=HTTPStatus.OK, ) @override_config({"next_link_domain_whitelist": None}) @@ -884,7 +888,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="some-protocol://abcdefghijklmopqrstuvwxyz", - expect_code=200, + expect_code=HTTPStatus.OK, ) @override_config({"next_link_domain_whitelist": None}) @@ -895,7 +899,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="file:///host/path", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) @@ -907,28 +911,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link=None, - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://example.com/some/good/page", - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://example.org/some/also/good/page", - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://bad.example.org/some/bad/page", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) @override_config({"next_link_domain_whitelist": []}) @@ -940,7 +944,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="https://example.com/a/page", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) def _request_token( @@ -948,8 +952,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): email: str, client_secret: str, next_link: Optional[str] = None, - expect_code: int = 200, - ) -> str: + expect_code: int = HTTPStatus.OK, + ) -> Optional[str]: """Request a validation token to add an email address to a user's account Args: @@ -959,7 +963,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): expect_code: Expected return code of the call Returns: - The ID of the new threepid validation session + The ID of the new threepid validation session, or None if the response + did not contain a session ID. """ body = {"client_secret": client_secret, "email": email, "send_attempt": 1} if next_link: @@ -992,16 +997,18 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(expected_errcode, channel.json_body["errcode"]) - self.assertEqual(expected_error, channel.json_body["error"]) + self.assertIn(expected_error, channel.json_body["error"]) def _validate_token(self, link: str) -> None: # Remove the host path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" @@ -1051,7 +1058,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -1060,7 +1067,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} @@ -1091,7 +1098,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): """Tests that not providing any MXID raises an error.""" self._test_status( users=None, - expected_status_code=400, + expected_status_code=HTTPStatus.BAD_REQUEST, expected_errcode=Codes.MISSING_PARAM, ) @@ -1099,7 +1106,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): """Tests that providing an invalid MXID raises an error.""" self._test_status( users=["bad:test"], - expected_status_code=400, + expected_status_code=HTTPStatus.BAD_REQUEST, expected_errcode=Codes.INVALID_PARAM, ) @@ -1285,7 +1292,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): def _test_status( self, users: Optional[List[str]], - expected_status_code: int = 200, + expected_status_code: int = HTTPStatus.OK, expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, expected_failures: Optional[List[str]] = None, expected_errcode: Optional[str] = None, diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index aca03afd0e..7a88aa2cda 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py
@@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor +from synapse.appservice import ApplicationService from synapse.rest import admin from synapse.rest.client import directory, login, room from synapse.server import HomeServer @@ -96,8 +96,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # We use deliberately a localpart under the length threshold so # that we can make sure that the check is done on the whole alias. - data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} - request_data = json.dumps(data) + request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) @@ -109,8 +108,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # Check with an alias of allowed length. There should already be # a test that ensures it works in test_register.py, but let's be # as cautious as possible here. - data = {"room_alias_name": random_string(5)} - request_data = json.dumps(data) + request_data = {"room_alias_name": random_string(5)} channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) @@ -129,6 +127,38 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + def test_deleting_alias_via_directory_appservice(self) -> None: + user_id = "@as:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + id="1234", + namespaces={"aliases": [{"regex": "#asns-*", "exclusive": True}]}, + sender=user_id, + ) + self.hs.get_datastores().main.services_cache.append(appservice) + + # Add an alias for the room, as the appservice + alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string() + request_data = {"room_id": self.room_id} + + channel = self.make_request( + "PUT", + f"/_matrix/client/r0/directory/room/{alias}", + request_data, + access_token=as_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + # Then try to remove the alias, as the appservice + channel = self.make_request( + "DELETE", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=as_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + def test_deleting_nonexistant_alias(self) -> None: # Check that no alias exists alias = "#potato:test" @@ -159,8 +189,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.hs.hostname, ) - data = {"aliases": [self.random_alias(alias_length)]} - request_data = json.dumps(data) + request_data = {"aliases": [self.random_alias(alias_length)]} channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok @@ -172,8 +201,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) -> str: alias = self.random_alias(alias_length) url = "/_matrix/client/r0/directory/room/%s" % alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok @@ -181,6 +209,19 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, expected_code, channel.result) return alias + def test_invalid_alias(self) -> None: + alias = "#potato" + channel = self.make_request( + "GET", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertIn("error", channel.json_body, channel.json_body) + self.assertEqual( + channel.json_body["errcode"], "M_INVALID_PARAM", channel.json_body + ) + def random_alias(self, length: int) -> str: return RoomAlias(random_string(length), self.hs.hostname).to_string() diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 823e8ab8c4..afc8d641be 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py
@@ -43,7 +43,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.EXAMPLE_FILTER_JSON, ) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"filter_id": "0"}) filter = self.get_success( self.store.get_user_filter(user_localpart="apple", filter_id=0) @@ -58,7 +58,7 @@ class FilterTestCase(unittest.HomeserverTestCase): self.EXAMPLE_FILTER_JSON, ) - self.assertEqual(channel.result["code"], b"403") + self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_add_filter_non_local_user(self) -> None: @@ -71,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase): ) self.hs.is_mine = _is_mine - self.assertEqual(channel.result["code"], b"403") + self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) def test_get_filter(self) -> None: @@ -85,7 +85,7 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) ) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, self.EXAMPLE_FILTER) def test_get_filter_non_existant(self) -> None: @@ -93,7 +93,7 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) ) - self.assertEqual(channel.result["code"], b"404") + self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) # Currently invalid params do not have an appropriate errcode @@ -103,7 +103,7 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) ) - self.assertEqual(channel.result["code"], b"400") + self.assertEqual(channel.code, 400) # No ID also returns an invalid_id error def test_get_filter_no_id(self) -> None: @@ -111,4 +111,4 @@ class FilterTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) ) - self.assertEqual(channel.result["code"], b"400") + self.assertEqual(channel.code, 400) diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 299b9d21e2..b0c8215744 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py
@@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor @@ -26,7 +25,6 @@ from tests import unittest class IdentityTestCase(unittest.HomeserverTestCase): - servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -34,7 +32,6 @@ class IdentityTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() config["enable_3pid_lookup"] = False self.hs = self.setup_test_homeserver(config=config) @@ -51,12 +48,12 @@ class IdentityTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK, channel.result) room_id = channel.json_body["room_id"] - params = { + request_data = { "id_server": "testis", "medium": "email", "address": "test@example.com", + "id_access_token": tok, } - request_data = json.dumps(params) request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") channel = self.make_request( b"POST", request_url, request_data, access_token=tok diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f4ea1209d9..e2a4d98275 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py
@@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import time import urllib.parse -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from unittest.mock import Mock from urllib.parse import urlencode @@ -41,7 +40,7 @@ from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless try: - import jwt + from authlib.jose import jwk, jwt HAS_JWT = True except ImportError: @@ -134,10 +133,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -152,7 +151,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config( { @@ -179,10 +178,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -197,7 +196,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config( { @@ -224,10 +223,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -242,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) @override_config({"session_lifetime": "24h"}) def test_soft_logout(self) -> None: @@ -250,7 +249,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we shouldn't be able to make requests without an access token channel = self.make_request(b"GET", TEST_URL) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") # log in as normal @@ -354,7 +353,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # Now try to hard logout this session channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( @@ -380,7 +379,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # Now try to hard log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_login_with_overly_long_device_id_fails(self) -> None: self.register_user("mickey", "cheese") @@ -399,7 +398,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "/_matrix/client/v3/login", - json.dumps(body).encode("utf8"), + body, custom_headers=None, ) @@ -841,7 +840,7 @@ class CASTestCase(unittest.HomeserverTestCase): self.assertIn(b"SSO account deactivated", channel.result["body"]) -@skip_unless(HAS_JWT, "requires jwt") +@skip_unless(HAS_JWT, "requires authlib") class JWTTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -866,11 +865,9 @@ class JWTTestCase(unittest.HomeserverTestCase): return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) - if isinstance(result, bytes): - return result.decode("ascii") - return result + header = {"alg": self.jwt_algorithm} + result: bytes = jwt.encode(header, payload, secret) + return result.decode("ascii") def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} @@ -880,17 +877,17 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_valid_registered(self) -> None: self.register_user("kermit", "monkey") channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") def test_login_jwt_valid_unregistered(self) -> None: channel = self.jwt_login({"sub": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, "notsecret") - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -899,25 +896,26 @@ class JWTTestCase(unittest.HomeserverTestCase): def test_login_jwt_expired(self) -> None: channel = self.jwt_login({"sub": "frog", "exp": 864000}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Signature has expired" + channel.json_body["error"], + "JWT validation failed: expired_token: The token is expired", ) def test_login_jwt_not_before(self) -> None: now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - "JWT validation failed: The token is not yet valid (nbf)", + "JWT validation failed: invalid_token: The token is not valid yet", ) def test_login_no_sub(self) -> None: channel = self.jwt_login({"username": "root"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @@ -926,30 +924,31 @@ class JWTTestCase(unittest.HomeserverTestCase): """Test validating the issuer claim.""" # A valid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid issuer" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "iss"', ) # Not providing an issuer. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - 'JWT validation failed: Token is missing the "iss" claim', + 'JWT validation failed: missing_claim: Missing "iss" claim', ) def test_login_iss_no_config(self) -> None: """Test providing an issuer claim without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) @@ -957,52 +956,54 @@ class JWTTestCase(unittest.HomeserverTestCase): """Test validating the audience claim.""" # A valid audience. channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid audience. channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "aud"', ) # Not providing an audience. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - 'JWT validation failed: Token is missing the "aud" claim', + 'JWT validation failed: missing_claim: Missing "aud" claim', ) def test_login_aud_no_config(self) -> None: """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "aud"', ) def test_login_default_sub(self) -> None: """Test reading user ID from the default subject claim.""" channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) def test_login_custom_sub(self) -> None: """Test reading user ID from a custom subject claim.""" channel = self.jwt_login({"username": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") @@ -1010,7 +1011,7 @@ class JWTTestCase(unittest.HomeserverTestCase): # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use # RSS256, with a public key configured in synapse as "jwt_secret", and tokens # signed by the private key. -@skip_unless(HAS_JWT, "requires jwt") +@skip_unless(HAS_JWT, "requires authlib") class JWTPubKeyTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, @@ -1071,11 +1072,11 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") - if isinstance(result, bytes): - return result.decode("ascii") - return result + header = {"alg": "RS256"} + if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"): + secret = jwk.dumps(secret, kty="RSA") + result: bytes = jwt.encode(header, payload, secret) + return result.decode("ascii") def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} @@ -1084,12 +1085,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def test_login_jwt_valid(self) -> None: channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -1150,7 +1151,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" @@ -1164,7 +1165,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" @@ -1178,7 +1179,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" @@ -1192,7 +1193,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): b"POST", LOGIN_URL, params, access_token=self.another_service.token ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice @@ -1206,7 +1207,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) @skip_unless(HAS_OIDC, "requires OIDC") diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py new file mode 100644
index 0000000000..a9da00665e --- /dev/null +++ b/tests/rest/client/test_models.py
@@ -0,0 +1,53 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from pydantic import ValidationError + +from synapse.rest.client.models import EmailRequestTokenBody + + +class EmailRequestTokenBodyTestCase(unittest.TestCase): + base_request = { + "client_secret": "hunter2", + "email": "alice@wonderland.com", + "send_attempt": 1, + } + + def test_token_required_if_id_server_provided(self) -> None: + with self.assertRaises(ValidationError): + EmailRequestTokenBody.parse_obj( + { + **self.base_request, + "id_server": "identity.wonderland.com", + } + ) + with self.assertRaises(ValidationError): + EmailRequestTokenBody.parse_obj( + { + **self.base_request, + "id_server": "identity.wonderland.com", + "id_access_token": None, + } + ) + + def test_token_typechecked_when_id_server_provided(self) -> None: + with self.assertRaises(ValidationError): + EmailRequestTokenBody.parse_obj( + { + **self.base_request, + "id_server": "identity.wonderland.com", + "id_access_token": 1337, + } + ) diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 3a74d2e96c..e19d21d6ee 100644 --- a/tests/rest/client/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py
@@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor @@ -89,7 +88,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_too_short(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "shorty"}) + request_data = {"username": "kermit", "password": "shorty"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -100,7 +99,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_digit(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + request_data = {"username": "kermit", "password": "longerpassword"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -111,7 +110,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_symbol(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + request_data = {"username": "kermit", "password": "l0ngerpassword"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -122,7 +121,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_uppercase(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + request_data = {"username": "kermit", "password": "l0ngerpassword!"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -133,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_lowercase(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + request_data = {"username": "kermit", "password": "L0NGERPASSWORD!"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -144,7 +143,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_compliant(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + request_data = {"username": "kermit", "password": "L0ngerpassword!"} channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has @@ -161,16 +160,14 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", compliant_password) tok = self.login("kermit", compliant_password) - request_data = json.dumps( - { - "new_password": not_compliant_password, - "auth": { - "password": compliant_password, - "type": LoginType.PASSWORD, - "user": user_id, - }, - } - ) + request_data = { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } channel = self.make_request( "POST", "/_matrix/client/r0/account/password", diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 77c3ced42e..8de5a342ae 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py
@@ -13,6 +13,8 @@ # limitations under the License. """Tests REST events for /profile paths.""" +import urllib.parse +from http import HTTPStatus from typing import Any, Dict, Optional from twisted.test.proto_helpers import MemoryReactor @@ -49,6 +51,12 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_displayname() self.assertEqual(res, "owner") + def test_get_displayname_rejects_bad_username(self) -> None: + channel = self.make_request( + "GET", f"/profile/{urllib.parse.quote('@alice:')}/displayname" + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + def test_set_displayname(self) -> None: channel = self.make_request( "PUT", @@ -145,18 +153,22 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def _get_displayname(self, name: Optional[str] = None) -> str: + def _get_displayname(self, name: Optional[str] = None) -> Optional[str]: channel = self.make_request( "GET", "/profile/%s/displayname" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) - return channel.json_body["displayname"] + # FIXME: If a user has no displayname set, Synapse returns 200 and omits a + # displayname from the response. This contradicts the spec, see #13137. + return channel.json_body.get("displayname") - def _get_avatar_url(self, name: Optional[str] = None) -> str: + def _get_avatar_url(self, name: Optional[str] = None) -> Optional[str]: channel = self.make_request( "GET", "/profile/%s/avatar_url" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) + # FIXME: If a user has no avatar set, Synapse returns 200 and omits an + # avatar_url from the response. This contradicts the spec, see #13137. return channel.json_body.get("avatar_url") @unittest.override_config({"max_avatar_size": 50}) diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 7401b5e0c0..be4c67d68e 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py
@@ -76,12 +76,12 @@ class RedactionsTestCase(HomeserverTestCase): path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) channel = self.make_request("POST", path, content={}, access_token=access_token) - self.assertEqual(int(channel.result["code"]), expect_code) + self.assertEqual(channel.code, expect_code) return channel.json_body def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]: channel = self.make_request("GET", "sync", access_token=self.mod_access_token) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, 200) room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index afb08b2736..b781875d52 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py
@@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import datetime -import json import os from typing import Any, Dict, List, Tuple @@ -62,15 +61,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) self.hs.get_datastores().main.services_cache.append(appservice) - request_data = json.dumps( - {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) + request_data = { + "username": "as_user_kermit", + "type": APP_SERVICE_REGISTRATION_TYPE, + } channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) @@ -85,49 +85,46 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) self.hs.get_datastores().main.services_cache.append(appservice) - request_data = json.dumps({"username": "as_user_kermit"}) + request_data = {"username": "as_user_kermit"} channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, 400, msg=channel.result) def test_POST_appservice_registration_invalid(self) -> None: self.appservice = None # no application service exists - request_data = json.dumps( - {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) + request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) def test_POST_bad_password(self) -> None: - request_data = json.dumps({"username": "kermit", "password": 666}) + request_data = {"username": "kermit", "password": 666} channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, 400, msg=channel.result) self.assertEqual(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self) -> None: - request_data = json.dumps({"username": 777, "password": "monkey"}) + request_data = {"username": 777, "password": "monkey"} channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, 400, msg=channel.result) self.assertEqual(channel.json_body["error"], "Invalid username") def test_POST_user_valid(self) -> None: user_id = "@kermit:test" device_id = "frogfone" - params = { + request_data = { "username": "kermit", "password": "monkey", "device_id": device_id, "auth": {"type": LoginType.DUMMY}, } - request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) det_data = { @@ -135,17 +132,17 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) def test_POST_disabled_registration(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request_data = {"username": "kermit", "password": "monkey"} self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["error"], "Registration has been disabled") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -156,7 +153,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self) -> None: @@ -164,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["error"], "Guest access is disabled") @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) @@ -174,40 +171,39 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request(b"POST", url, b"{}") if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self) -> None: for i in range(0, 6): - params = { + request_data = { "username": "kermit" + str(i), "password": "monkey", "device_id": "frogfone", "auth": {"type": LoginType.DUMMY}, } - request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) @override_config({"registration_requires_token": True}) def test_POST_registration_requires_token(self) -> None: @@ -234,8 +230,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } # Request without auth to get flows and session - channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] # Synapse adds a dummy stage to differentiate flows where otherwise one # flow would be a subset of another flow. @@ -251,9 +247,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -262,14 +257,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "type": LoginType.DUMMY, "session": session, } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, params) det_data = { "user_id": f"@{username}:{self.hs.hostname}", "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) # Check the `completed` counter has been incremented and pending is 0 @@ -290,7 +284,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "password": "monkey", } # Request without auth to get session - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Test with token param missing (invalid) @@ -298,22 +292,22 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "type": LoginType.REGISTRATION_TOKEN, "session": session, } - channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with non-string (invalid) params["auth"]["token"] = 1234 - channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with unknown token (invalid) params["auth"]["token"] = "1234" - channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -337,9 +331,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): params1: JsonDict = {"username": "bert", "password": "monkey"} params2: JsonDict = {"username": "ernie", "password": "monkey"} # Do 2 requests without auth to get two session IDs - channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + channel1 = self.make_request(b"POST", self.url, params1) session1 = channel1.json_body["session"] - channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + channel2 = self.make_request(b"POST", self.url, params2) session2 = channel2.json_body["session"] # Use token with session1 and check `pending` is 1 @@ -348,9 +342,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session1, } - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Repeat request to make sure pending isn't increased again - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) pending = self.get_success( store.db_pool.simple_select_one_onecol( "registration_tokens", @@ -366,14 +360,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session2, } - channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params2) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Check pending=0 and completed=1 res = self.get_success( store.db_pool.simple_select_one( @@ -386,8 +380,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(res["completed"], 1) # Check auth still fails when using token with session2 - channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params2) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -411,7 +405,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) params: JsonDict = {"username": "kermit", "password": "monkey"} # Request without auth to get session - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Check authentication fails with expired token @@ -420,8 +414,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + channel = self.make_request(b"POST", self.url, params) + self.assertEqual(channel.code, 401, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -435,7 +429,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) # Check authentication succeeds - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -460,9 +454,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Do 2 requests without auth to get two session IDs params1: JsonDict = {"username": "bert", "password": "monkey"} params2: JsonDict = {"username": "ernie", "password": "monkey"} - channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + channel1 = self.make_request(b"POST", self.url, params1) session1 = channel1.json_body["session"] - channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + channel2 = self.make_request(b"POST", self.url, params2) session2 = channel2.json_body["session"] # Use token with both sessions @@ -471,18 +465,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session1, } - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) params2["auth"] = { "type": LoginType.REGISTRATION_TOKEN, "token": token, "session": session2, } - self.make_request(b"POST", self.url, json.dumps(params2)) + self.make_request(b"POST", self.url, params2) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Check `result` of registration token stage for session1 is `True` result1 = self.get_success( @@ -550,7 +544,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Do request without auth to get a session ID params: JsonDict = {"username": "kermit", "password": "monkey"} - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Use token @@ -559,7 +553,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - self.make_request(b"POST", self.url, json.dumps(params)) + self.make_request(b"POST", self.url, params) # Delete token self.get_success( @@ -576,7 +570,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_advertised_flows(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] # with the stock config, we only expect the dummy flow @@ -592,14 +586,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "require_at_registration": True, }, "account_threepid_delegates": { - "email": "https://id_server", "msisdn": "https://id_server", }, + "email": {"notif_from": "Synapse <synapse@example.com>"}, } ) def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] self.assertCountEqual( @@ -631,7 +625,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) def test_advertised_flows_no_msisdn_email_required(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, 401, msg=channel.result) flows = channel.json_body["flows"] # with the stock config, we expect all four combinations of 3pid @@ -803,13 +797,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -827,15 +821,14 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = {"user_id": user_id} - request_data = json.dumps(params) + request_data = {"user_id": user_id} channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_manual_expire(self) -> None: user_id = self.register_user("kermit", "monkey") @@ -845,19 +838,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = { + request_data = { "user_id": user_id, "expiration_ts": 0, "enable_renewal_emails": False, } - request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -870,25 +862,24 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = { + request_data = { "user_id": user_id, "expiration_ts": 0, "enable_renewal_emails": False, } - request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Try to log the user out channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): @@ -963,7 +954,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -981,7 +972,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Move 1 day forward. Try to renew with the same token again. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1001,14 +992,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) def test_renewal_invalid_token(self) -> None: # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"404", channel.result) + self.assertEqual(channel.code, 404, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1032,7 +1023,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1041,16 +1032,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): (user_id, tok) = self.create_user() - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "monkey", - }, - "erase": False, - } - ) + request_data = { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + } channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) @@ -1107,7 +1096,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1187,7 +1176,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["valid"], True) def test_GET_token_invalid(self) -> None: @@ -1196,7 +1185,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["valid"], False) @override_config( @@ -1212,10 +1201,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): ) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -1223,4 +1212,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 62e4db23ef..651f4f415d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py
@@ -728,6 +728,7 @@ class RelationsTestCase(BaseRelationsTestCase): class RelationPaginationTestCase(BaseRelationsTestCase): + @unittest.override_config({"experimental_features": {"msc3715_enabled": True}}) def test_basic_paginate_relations(self) -> None: """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") @@ -799,7 +800,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): ) expected_event_ids.append(channel.json_body["event_id"]) - prev_token = "" + prev_token: Optional[str] = "" found_event_ids: List[str] = [] for _ in range(20): from_token = "" @@ -998,7 +999,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6) + self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) def test_annotation_to_annotation(self) -> None: """Any relation to an annotation should be ignored.""" @@ -1034,7 +1035,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6) + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) def test_thread(self) -> None: """ @@ -1059,6 +1060,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): participated, bundled_aggregations.get("current_user_participated") ) # The latest thread event has some fields that don't matter. + self.assertIn("latest_event", bundled_aggregations) self.assert_dict( { "content": { @@ -1071,28 +1073,28 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): "sender": self.user2_id, "type": "m.room.test", }, - bundled_aggregations.get("latest_event"), + bundled_aggregations["latest_event"], ) return assert_thread # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # # Note that this re-uses some cached values, so the total number of # queries is much smaller. self._test_bundled_aggregations( - RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token + RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token ) # A user with no interactions with the thread: they have not participated. user3_id, user3_token = self._create_user("charlie") self.helper.join(self.room, user=user3_id, tok=user3_token) self._test_bundled_aggregations( - RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token + RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token ) def test_thread_with_bundled_aggregations_for_latest(self) -> None: @@ -1111,6 +1113,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assertEqual(2, bundled_aggregations.get("count")) self.assertTrue(bundled_aggregations.get("current_user_participated")) # The latest thread event has some fields that don't matter. + self.assertIn("latest_event", bundled_aggregations) self.assert_dict( { "content": { @@ -1123,7 +1126,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): "sender": self.user_id, "type": "m.room.test", }, - bundled_aggregations.get("latest_event"), + bundled_aggregations["latest_event"], ) # Check the unsigned field on the latest event. self.assert_dict( @@ -1139,7 +1142,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) def test_nested_thread(self) -> None: """ diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 20a259fc43..7cb1017a4a 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py
@@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json - from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -77,11 +75,6 @@ class ReportEventTestCase(unittest.HomeserverTestCase): def _assert_status(self, response_status: int, data: JsonDict) -> None: channel = self.make_request( - "POST", - self.report_path, - json.dumps(data), - access_token=self.other_user_tok, - ) - self.assertEqual( - response_status, int(channel.result["code"]), msg=channel.result["body"] + "POST", self.report_path, data, access_token=self.other_user_tok ) + self.assertEqual(response_status, channel.code, msg=channel.result["body"]) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index ac9c113354..9c8c1889d3 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict +from synapse.types import JsonDict, create_requester from synapse.util import Clock from synapse.visibility import filter_events_for_client @@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): message_handler = self.hs.get_message_handler() create_event = self.get_success( message_handler.get_room_data( - self.user_id, room_id, EventTypes.Create, state_key="" + create_requester(self.user_id), room_id, EventTypes.Create, state_key="" ) ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index f523d89b8f..c7eb88d33f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -18,10 +18,14 @@ """Tests REST events for /rooms paths.""" import json -from typing import Any, Dict, Iterable, List, Optional +from http import HTTPStatus +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from unittest.mock import Mock, call from urllib import parse as urlparse +from parameterized import param, parameterized +from typing_extensions import Literal + from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -30,7 +34,9 @@ from synapse.api.constants import ( EventContentFields, EventTypes, Membership, + PublicRoomsFilterFields, RelationTypes, + RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus @@ -42,6 +48,7 @@ from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest +from tests.http.server._base import make_request_with_cancellation_test from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -98,7 +105,7 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # set topic for public room channel = self.make_request( @@ -106,7 +113,7 @@ class RoomPermissionsTestCase(RoomBase): ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -128,28 +135,28 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_topic_perms(self) -> None: topic_content = b'{"topic":"My Topic Name"}' @@ -159,28 +166,28 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 channel = self.make_request("GET", topic_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) @@ -188,25 +195,25 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id channel = self.make_request("GET", topic_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 channel = self.make_request( @@ -214,7 +221,7 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def _test_get_membership( self, room: str, members: Iterable = frozenset(), expect_code: int = 200 @@ -303,14 +310,14 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=self.rmcreator_id, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) self.helper.change_membership( room=room, src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) def test_joined_permissions(self) -> None: @@ -336,7 +343,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # set left of other, expect 403 @@ -345,7 +352,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # set left of self, expect 200 @@ -365,7 +372,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=usr, membership=Membership.INVITE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) self.helper.change_membership( @@ -373,7 +380,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=usr, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # It is always valid to LEAVE if you've already left (currently.) @@ -382,7 +389,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember @@ -399,7 +406,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.BAN, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.FORBIDDEN, ) @@ -409,7 +416,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.BAN, - expect_code=200, + expect_code=HTTPStatus.OK, ) # from ban to invite: Must never happen. @@ -418,7 +425,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.INVITE, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -428,7 +435,7 @@ class RoomPermissionsTestCase(RoomBase): src=other, targ=other, membership=Membership.JOIN, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -438,7 +445,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.BAN, - expect_code=200, + expect_code=HTTPStatus.OK, ) # from ban to knock: Must never happen. @@ -447,7 +454,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.KNOCK, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -457,7 +464,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.LEAVE, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.FORBIDDEN, ) @@ -467,10 +474,53 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.LEAVE, - expect_code=200, + expect_code=HTTPStatus.OK, ) +class RoomStateTestCase(RoomBase): + """Tests /rooms/$room_id/state.""" + + user_id = "@sid1:red" + + def test_get_state_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/state` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_state_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/state" % room_id, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertCountEqual( + [state_event["type"] for state_event in channel.json_list], + { + "m.room.create", + "m.room.power_levels", + "m.room.join_rules", + "m.room.member", + "m.room.history_visibility", + }, + ) + + def test_get_state_event_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/state/$event_type` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_state_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id), + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertEqual(channel.json_body, {"membership": "join"}) + + class RoomsMemberListTestCase(RoomBase): """Tests /rooms/$room_id/members/list REST events.""" @@ -481,16 +531,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self) -> None: channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self) -> None: room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_with_at_token(self) -> None: """ @@ -501,7 +551,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that permission is denied for @sid1:red to get the @@ -510,7 +560,7 @@ class RoomsMemberListTestCase(RoomBase): "GET", f"/rooms/{room_id}/members?at={sync_token}", ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member(self) -> None: """ @@ -523,14 +573,14 @@ class RoomsMemberListTestCase(RoomBase): # check that the user can see the member list to start with channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") # check the user can no longer see the member list channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member_with_at_token(self) -> None: """ @@ -544,14 +594,14 @@ class RoomsMemberListTestCase(RoomBase): # sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that the user can see the member list to start with channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user (Note: the user is actually allowed to see this event and # state so that they know they're banned!) @@ -563,14 +613,14 @@ class RoomsMemberListTestCase(RoomBase): # now, with the original user, sync again to get a new at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check the user can no longer see the updated member list channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self) -> None: room_creator = "@some_other_guy:red" @@ -579,17 +629,73 @@ class RoomsMemberListTestCase(RoomBase): self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. channel = self.make_request("GET", room_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + + def test_get_member_list_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/members` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_get_member_list_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/members" % room_id, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertLessEqual( + { + "content": {"membership": "join"}, + "room_id": room_id, + "sender": self.user_id, + "state_key": self.user_id, + "type": "m.room.member", + "user_id": self.user_id, + }.items(), + channel.json_body["chunk"][0].items(), + ) + + def test_get_member_list_with_at_token_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/members?at=<sync token>` request.""" + room_id = self.helper.create_room_as(self.user_id) + + # first sync to get an at token + channel = self.make_request("GET", "/sync") + self.assertEqual(HTTPStatus.OK, channel.code) + sync_token = channel.json_body["next_batch"] + + channel = make_request_with_cancellation_test( + "test_get_member_list_with_at_token_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/members?at=%s" % (room_id, sync_token), + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertLessEqual( + { + "content": {"membership": "join"}, + "room_id": room_id, + "sender": self.user_id, + "state_key": self.user_id, + "type": "m.room.member", + "user_id": self.user_id, + }.items(), + channel.json_body["chunk"][0].items(), + ) class RoomsCreateTestCase(RoomBase): @@ -601,19 +707,34 @@ class RoomsCreateTestCase(RoomBase): # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(44, channel.resource_usage.db_txn_count) + + def test_post_room_initial_state(self) -> None: + # POST with initial_state config key, expect new room id + channel = self.make_request( + "POST", + "/createRoom", + b'{"initial_state":[{"type": "m.bridge", "content": {}}]}', + ) + + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + self.assertTrue("room_id" in channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(50, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self) -> None: # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self) -> None: @@ -621,16 +742,16 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self) -> None: # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) def test_post_room_invitees_invalid_mxid(self) -> None: # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 @@ -638,7 +759,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) def test_post_room_invitees_ratelimit(self) -> None: @@ -649,20 +770,18 @@ class RoomsCreateTestCase(RoomBase): # Build the request's content. We use local MXIDs because invites over federation # are more difficult to mock. - content = json.dumps( - { - "invite": [ - "@alice1:red", - "@alice2:red", - "@alice3:red", - "@alice4:red", - ] - } - ).encode("utf8") + content = { + "invite": [ + "@alice1:red", + "@alice2:red", + "@alice3:red", + "@alice4:red", + ] + } # Test that the invites are correctly ratelimited. channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) self.assertEqual( "Cannot invite so many users at once", channel.json_body["error"], @@ -675,11 +794,13 @@ class RoomsCreateTestCase(RoomBase): # Test that the invites aren't ratelimited anymore. channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) - def test_spam_checker_may_join_room(self) -> None: + def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly bypassed when creating a new room. + + In this test, we use the deprecated API in which callbacks return a bool. """ async def user_may_join_room( @@ -697,10 +818,55 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) + def test_spam_checker_may_join_room(self) -> None: + """Tests that the user_may_join_room spam checker callback is correctly bypassed + when creating a new room. + + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. + """ + + async def user_may_join_room_codes( + mxid: str, + room_id: str, + is_invite: bool, + ) -> Codes: + return Codes.CONSENT_NOT_GIVEN + + join_mock = Mock(side_effect=user_may_join_room_codes) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock) + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + self.assertEqual(join_mock.call_count, 0) + + # Now change the return value of the callback to deny any join. Since we're + # creating the room, despite the return value, we should be able to join. + async def user_may_join_room_tuple( + mxid: str, + room_id: str, + is_invite: bool, + ) -> Tuple[Codes, dict]: + return Codes.INCOMPATIBLE_ROOM_VERSION, {} + + join_mock.side_effect = user_may_join_room_tuple + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + self.assertEqual(join_mock.call_count, 0) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -715,54 +881,68 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self) -> None: # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid key, wrong type content = '{"topic":["Topic name"]}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_topic(self) -> None: # nothing should be there channel = self.make_request("GET", self.path) - self.assertEqual(404, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self) -> None: # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -778,22 +958,34 @@ class RoomMemberStateTestCase(RoomBase): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -802,7 +994,9 @@ class RoomMemberStateTestCase(RoomBase): Membership.LEAVE, ) channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_members_self(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % ( @@ -813,10 +1007,10 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEqual(expected_response, channel.json_body) @@ -831,10 +1025,10 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self) -> None: @@ -850,10 +1044,10 @@ class RoomMemberStateTestCase(RoomBase): "Join us!", ) channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) @@ -911,9 +1105,11 @@ class RoomJoinTestCase(RoomBase): self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) - def test_spam_checker_may_join_room(self) -> None: + def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called and blocks room joins when needed. + + This test uses the deprecated API, in which callbacks return booleans. """ # Register a dummy callback. Make it allow all room joins for now. @@ -926,6 +1122,8 @@ class RoomJoinTestCase(RoomBase): ) -> bool: return return_value + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None) self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) @@ -966,7 +1164,92 @@ class RoomJoinTestCase(RoomBase): # Now make the callback deny all room joins, and check that a join actually fails. return_value = False - self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + self.helper.join( + self.room3, self.user2, expect_code=HTTPStatus.FORBIDDEN, tok=self.tok2 + ) + + def test_spam_checker_may_join_room(self) -> None: + """Tests that the user_may_join_room spam checker callback is correctly called + and blocks room joins when needed. + + This test uses the latest API to this day, in which callbacks return `NOT_SPAM` or `Codes`. + """ + + # Register a dummy callback. Make it allow all room joins for now. + return_value: Union[ + Literal["NOT_SPAM"], Tuple[Codes, dict], Codes + ] = synapse.module_api.NOT_SPAM + + async def user_may_join_room( + userid: str, + room_id: str, + is_invited: bool, + ) -> Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes]: + return return_value + + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. + callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) + + # Join a first room, without being invited to it. + self.helper.join(self.room1, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room1, + False, + ), + ) + self.assertEqual( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Join a second room, this time with an invite for it. + self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1) + self.helper.join(self.room2, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room2, + True, + ), + ) + self.assertEqual( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + # We pick an arbitrary Codes rather than the default `Codes.FORBIDDEN`. + return_value = Codes.CONSENT_NOT_GIVEN + self.helper.invite(self.room3, self.user1, self.user2, tok=self.tok1) + self.helper.join( + self.room3, + self.user2, + expect_code=HTTPStatus.FORBIDDEN, + expect_errcode=return_value, + tok=self.tok2, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + # As above, with the experimental extension that lets us return dictionaries. + return_value = (Codes.BAD_ALIAS, {"another_field": "12345"}) + self.helper.join( + self.room3, + self.user2, + expect_code=HTTPStatus.FORBIDDEN, + expect_errcode=return_value[0], + tok=self.tok2, + expect_additional_fields=return_value[1], + ) class RoomJoinRatelimitTestCase(RoomBase): @@ -1016,7 +1299,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that all the rooms have been sent a profile update into. for room_id in room_ids: @@ -1081,40 +1364,153 @@ class RoomMessagesTestCase(RoomBase): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_messages_sent(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' channel = self.make_request("PUT", path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + + @parameterized.expand( + [ + # Allow + param( + name="NOT_SPAM", + value="NOT_SPAM", + expected_code=HTTPStatus.OK, + expected_fields={}, + ), + param( + name="False", + value=False, + expected_code=HTTPStatus.OK, + expected_fields={}, + ), + # Block + param( + name="scalene string", + value="ANY OTHER STRING", + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_FORBIDDEN"}, + ), + param( + name="True", + value=True, + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_FORBIDDEN"}, + ), + param( + name="Code", + value=Codes.LIMIT_EXCEEDED, + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_LIMIT_EXCEEDED"}, + ), + param( + name="Tuple", + value=(Codes.SERVER_NOT_TRUSTED, {"additional_field": "12345"}), + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={ + "errcode": "M_SERVER_NOT_TRUSTED", + "additional_field": "12345", + }, + ), + ] + ) + def test_spam_checker_check_event_for_spam( + self, + name: str, + value: Union[str, bool, Codes, Tuple[Codes, JsonDict]], + expected_code: int, + expected_fields: dict, + ) -> None: + class SpamCheck: + mock_return_value: Union[ + str, bool, Codes, Tuple[Codes, JsonDict], bool + ] = "NOT_SPAM" + mock_content: Optional[JsonDict] = None + + async def check_event_for_spam( + self, + event: synapse.events.EventBase, + ) -> Union[str, Codes, Tuple[Codes, JsonDict], bool]: + self.mock_content = event.content + return self.mock_return_value + + spam_checker = SpamCheck() + + self.hs.get_spam_checker()._check_event_for_spam_callbacks.append( + spam_checker.check_event_for_spam + ) + + # Inject `value` as mock_return_value + spam_checker.mock_return_value = value + path = "/rooms/%s/send/m.room.message/check_event_for_spam_%s" % ( + urlparse.quote(self.room_id), + urlparse.quote(name), + ) + body = "test-%s" % name + content = '{"body":"%s","msgtype":"m.text"}' % body + channel = self.make_request("PUT", path, content) + + # Check that the callback has witnessed the correct event. + self.assertIsNotNone(spam_checker.mock_content) + if ( + spam_checker.mock_content is not None + ): # Checked just above, but mypy doesn't know about that. + self.assertEqual( + spam_checker.mock_content["body"], body, spam_checker.mock_content + ) + + # Check that we have the correct result. + self.assertEqual(expected_code, channel.code, msg=channel.result["body"]) + for expected_key, expected_value in expected_fields.items(): + self.assertEqual( + channel.json_body.get(expected_key, None), + expected_value, + "Field %s absent or invalid " % expected_key, + ) class RoomPowerLevelOverridesTestCase(RoomBase): @@ -1239,7 +1635,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_normal_user_can_not_post_state_event(self) -> None: # Given I am a normal member of a room @@ -1253,7 +1649,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed because state events require PL>=50 - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " "user_level (0) < send_level (50)", @@ -1280,7 +1676,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1308,7 +1704,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1336,7 +1732,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " + "user_level (0) < send_level (1)", @@ -1367,7 +1763,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): # Then I am not allowed because the public_chat config does not # affect this room, because this room is a private_chat - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " + "user_level (0) < send_level (50)", @@ -1386,7 +1782,7 @@ class RoomInitialSyncTestCase(RoomBase): def test_initial_sync(self) -> None: channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertEqual(self.room_id, channel.json_body["room_id"]) self.assertEqual("join", channel.json_body["membership"]) @@ -1429,7 +1825,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1440,7 +1836,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1479,7 +1875,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) @@ -1507,7 +1903,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) @@ -1524,7 +1920,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) @@ -1652,14 +2048,97 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): def test_restricted_no_auth(self) -> None: channel = self.make_request("GET", self.url) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) def test_restricted_auth(self) -> None: self.register_user("user", "pass") tok = self.login("user", "pass") channel = self.make_request("GET", self.url, access_token=tok) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + +class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + + config = self.default_config() + config["allow_public_rooms_without_auth"] = True + self.hs = self.setup_test_homeserver(config=config) + self.url = b"/_matrix/client/r0/publicRooms" + + return self.hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + user = self.register_user("alice", "pass") + self.token = self.login(user, "pass") + + # Create a room + self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=self.token, + ) + # Create a space + self.helper.create_room_as( + user, + is_public=True, + extra_content={ + "visibility": "public", + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}, + }, + tok=self.token, + ) + + def make_public_rooms_request( + self, room_types: Union[List[Union[str, None]], None] + ) -> Tuple[List[Dict[str, Any]], int]: + channel = self.make_request( + "POST", + self.url, + {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}, + self.token, + ) + chunk = channel.json_body["chunk"] + count = channel.json_body["total_room_count_estimate"] + + self.assertEqual(len(chunk), count) + + return chunk, count + + def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None: + chunk, count = self.make_public_rooms_request(None) + + self.assertEqual(count, 2) + + def test_returns_only_rooms_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request([None]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("room_type", None), None) + + def test_returns_only_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space"]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("room_type", None), "m.space") + + def test_returns_both_rooms_and_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space", None]) + + self.assertEqual(count, 2) + + def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None: + chunk, count = self.make_public_rooms_request([]) + + self.assertEqual(count, 2) class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): @@ -1686,7 +2165,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): "Simple test for searching rooms over federation" self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", @@ -1694,7 +2173,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined] "testserv", @@ -1711,11 +2190,11 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): # The `get_public_rooms` should be called again if the first call fails # with a 404, when using search filters. self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] - HttpResponseException(404, "Not Found", b""), + HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""), make_awaitable({}), ) - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", @@ -1723,7 +2202,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined] [ @@ -1769,21 +2248,19 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): # Set a profile for the test user self.displayname = "test user" - data = {"displayname": self.displayname} - request_data = json.dumps(data) + request_data = {"displayname": self.displayname} channel = self.make_request( "PUT", "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) def test_per_room_profile_forbidden(self) -> None: - data = {"membership": "join", "displayname": "other test user"} - request_data = json.dumps(data) + request_data = {"membership": "join", "displayname": "other test user"} channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" @@ -1791,7 +2268,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1799,7 +2276,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) @@ -1833,7 +2310,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1847,7 +2324,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1861,7 +2338,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1875,7 +2352,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1887,7 +2364,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1899,7 +2376,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1918,7 +2395,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -1930,7 +2407,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): ), access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_content = channel.json_body @@ -1978,7 +2455,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2008,7 +2485,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2043,7 +2520,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2123,16 +2600,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_labels(self) -> None: """Test that we can filter by a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2160,16 +2635,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2209,16 +2682,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /search request. """ - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2391,7 +2862,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) return channel.json_body["chunk"] @@ -2496,7 +2967,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2562,7 +3033,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=invited_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2663,8 +3134,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -2693,8 +3163,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -2720,7 +3189,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), - json.dumps(content), + content, access_token=self.room_owner_tok, ) self.assertEqual(channel.code, expected_code, channel.result) @@ -2845,11 +3314,16 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_threepid_invite_spamcheck(self) -> None: + def test_threepid_invite_spamcheck_deprecated(self) -> None: + """ + Test allowing/blocking threepid invites with a spam-check module. + + In this test, we use the deprecated API in which callbacks return a bool. + """ # Mock a few functions to prevent the test from failing due to failing to talk to - # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we + # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable(0)) + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock self.hs.get_identity_handler().lookup_3pid = Mock( return_value=make_awaitable(None), @@ -2901,3 +3375,107 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Also check that it stopped before calling _make_and_store_3pid_invite. make_invite_mock.assert_called_once() + + def test_threepid_invite_spamcheck(self) -> None: + """ + Test allowing/blocking threepid invites with a spam-check module. + + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.""" + # Mock a few functions to prevent the test from failing due to failing to talk to + # a remote IS. We keep the mock for make_and_store_3pid_invite around so we + # can check its call_count later on during the test. + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock + self.hs.get_identity_handler().lookup_3pid = Mock( + return_value=make_awaitable(None), + ) + + # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it + # allow everything for now. + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. + mock = Mock( + return_value=make_awaitable(synapse.module_api.NOT_SPAM), + spec=lambda *x: None, + ) + self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) + + # Send a 3PID invite into the room and check that it succeeded. + email_to_invite = "teresa@example.com" + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + + # Check that the callback was called with the right params. + mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) + + # Check that the call to send the invite was made. + make_invite_mock.assert_called_once() + + # Now change the return value of the callback to deny any invite and test that + # we can't send the invite. We pick an arbitrary error code to be able to check + # that the same code has been returned + mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.CONSENT_NOT_GIVEN) + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once() + + # Run variant with `Tuple[Codes, dict]`. + mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"})) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT) + self.assertEqual(channel.json_body["field"], "value") + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once() + + def test_400_missing_param_without_id_access_token(self) -> None: + """ + Test that a 3pid invite request returns 400 M_MISSING_PARAM + if we do not include id_access_token. + """ + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "medium": "email", + "address": "teresa@example.com", + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM") diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d9bd8c4a28..c807a37bc2 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py
@@ -26,7 +26,7 @@ from synapse.rest.client import ( room_upgrade_rest_servlet, ) from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import UserID, create_requester from synapse.util import Clock from tests import unittest @@ -97,7 +97,12 @@ class RoomTestCase(_ShadowBannedBase): channel = self.make_request( "POST", "/rooms/%s/invite" % (room_id,), - {"id_server": "test", "medium": "email", "address": "test@test.test"}, + { + "id_server": "test", + "medium": "email", + "address": "test@test.test", + "id_access_token": "anytoken", + }, access_token=self.banned_access_token, ) self.assertEqual(200, channel.code, channel.result) @@ -275,7 +280,7 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, + create_requester(self.banned_user_id), room_id, "m.room.member", self.banned_user_id, @@ -310,7 +315,7 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, + create_requester(self.banned_user_id), room_id, "m.room.member", self.banned_user_id, diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index e3efd1f1b0..0af643ecd9 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py
@@ -38,7 +38,6 @@ from tests.federation.transport.test_knocking import ( KnockingStrippedStateEventHelperMixin, ) from tests.server import TimedOutException -from tests.unittest import override_config class FilterTestCase(unittest.HomeserverTestCase): @@ -390,6 +389,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): sync.register_servlets, ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + + return self.setup_test_homeserver(config=config) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -408,7 +412,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Join the second user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - @override_config({"experimental_features": {"msc2285_enabled": True}}) def test_private_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -416,7 +419,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a private read receipt to tell the server the first user's message was read channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok2, ) @@ -425,7 +428,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that the first user can't see the other user's private read receipt self.assertIsNone(self._get_read_receipt()) - @override_config({"experimental_features": {"msc2285_enabled": True}}) def test_public_receipt_can_override_private(self) -> None: """ Sending a public read receipt to the same event which has a private read @@ -456,7 +458,6 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that we did override the private read receipt self.assertNotEqual(self._get_read_receipt(), None) - @override_config({"experimental_features": {"msc2285_enabled": True}}) def test_private_receipt_cannot_override_public(self) -> None: """ Sending a private read receipt to the same event which has a public read @@ -543,7 +544,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): config = super().default_config() config["experimental_features"] = { "msc2654_enabled": True, - "msc2285_enabled": True, } return config @@ -606,11 +606,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(1) # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8") channel = self.make_request( "POST", f"/rooms/{self.room_id}/read_markers", - body, + {ReceiptTypes.READ: res["event_id"]}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -625,7 +624,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Send a read receipt to tell the server we've read the latest event. channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", {}, access_token=self.tok, ) @@ -701,7 +700,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(5) res2 = self.helper.send(self.room_id, "hello", tok=self.tok2) - # Make sure both m.read and org.matrix.msc2285.read.private advance + # Make sure both m.read and m.read.private advance channel = self.make_request( "POST", f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}", @@ -713,16 +712,21 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) self._check_unread_count(0) - # We test for both receipt types that influence notification counts - @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]) - def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None: + # We test for all three receipt types that influence notification counts + @parameterized.expand( + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ] + ) + def test_read_receipts_only_go_down(self, receipt_type: str) -> None: # Join the new user self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) @@ -733,18 +737,18 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Read last event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res2['event_id']}", {}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) self._check_unread_count(0) - # Make sure neither m.read nor org.matrix.msc2285.read.private make the + # Make sure neither m.read nor m.read.private make the # read receipt go up to an older event channel = self.make_request( "POST", - f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res1['event_id']}", {}, access_token=self.tok, ) @@ -949,3 +953,24 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase): self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"]) self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"]) + + def test_incremental_sync(self) -> None: + """Tests that activity in the room is properly filtered out of incremental + syncs. + """ + channel = self.make_request("GET", "/sync", access_token=self.tok) + self.assertEqual(channel.code, 200, channel.result) + next_batch = channel.json_body["next_batch"] + + self.helper.send(self.excluded_room_id, tok=self.tok) + self.helper.send(self.included_room_id, tok=self.tok) + + channel = self.make_request( + "GET", + f"/sync?since={next_batch}", + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"]) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 5eb0f243f7..3325d43a2f 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py
@@ -20,8 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError from synapse.api.room_versions import RoomVersion +from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase -from synapse.events.snapshot import EventContext from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin from synapse.rest.client import account, login, profile, room @@ -113,14 +113,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Have this homeserver skip event auth checks. This is necessary due to # event auth checks ensuring that events were signed by the sender's homeserver. - async def _check_event_auth( - origin: str, - event: EventBase, - context: EventContext, - *args: Any, - **kwargs: Any, - ) -> EventContext: - return context + async def _check_event_auth(origin: Any, event: Any, context: Any) -> None: + pass hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] @@ -161,7 +155,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) callback.assert_called_once() @@ -179,7 +173,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, 403, channel.result) def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None: """ @@ -192,12 +186,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): """ class NastyHackException(SynapseError): - def error_dict(self) -> JsonDict: + def error_dict(self, config: Optional[HomeServerConfig]) -> JsonDict: """ This overrides SynapseError's `error_dict` to nastily inject JSON into the error response. """ - result = super().error_dict() + result = super().error_dict(config) result["nasty"] = "very" return result @@ -217,7 +211,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): access_token=self.tok, ) # Check the error code - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual(channel.code, 429, channel.result) # Check the JSON body has had the `nasty` key injected self.assertEqual( channel.json_body, @@ -266,7 +260,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): {"x": "x"}, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) event_id = channel.json_body["event_id"] # ... and check that it got modified @@ -275,7 +269,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) ev = channel.json_body self.assertEqual(ev["content"]["x"], "y") @@ -304,7 +298,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) orig_event_id = channel.json_body["event_id"] channel = self.make_request( @@ -321,7 +315,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) edited_event_id = channel.json_body["event_id"] # ... and check that they both got modified @@ -330,7 +324,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) ev = channel.json_body self.assertEqual(ev["content"]["body"], "ORIGINAL BODY") @@ -339,7 +333,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) ev = channel.json_body self.assertEqual(ev["content"]["body"], "EDITED BODY") @@ -385,7 +379,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): }, access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) event_id = channel.json_body["event_id"] @@ -394,7 +388,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, 200, channel.result) self.assertIn("foo", channel.json_body["content"].keys()) self.assertEqual(channel.json_body["content"]["foo"], "bar") diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 98c1039d33..5e7bf97482 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py
@@ -48,10 +48,14 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.helper.join(self.room_id, self.other, tok=self.other_token) def _upgrade_room( - self, token: Optional[str] = None, room_id: Optional[str] = None + self, + token: Optional[str] = None, + room_id: Optional[str] = None, + expire_cache: bool = True, ) -> FakeChannel: - # We never want a cached response. - self.reactor.advance(5 * 60 + 1) + if expire_cache: + # We don't want a cached response. + self.reactor.advance(5 * 60 + 1) if room_id is None: room_id = self.room_id @@ -72,9 +76,24 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) - def test_not_in_room(self) -> None: + new_room_id = channel.json_body["replacement_room"] + + # Check that the tombstone event points to the new room. + tombstone_event = self.get_success( + self.hs.get_storage_controllers().state.get_current_state_event( + self.room_id, EventTypes.Tombstone, "" + ) + ) + self.assertIsNotNone(tombstone_event) + self.assertEqual(new_room_id, tombstone_event.content["replacement_room"]) + + # Check that the new room exists. + room = self.get_success(self.store.get_room(new_room_id)) + self.assertIsNotNone(room) + + def test_never_in_room(self) -> None: """ - Upgrading a room should work fine. + A user who has never been in the room cannot upgrade the room. """ # The user isn't in the room. roomless = self.register_user("roomless", "pass") @@ -83,6 +102,16 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): channel = self._upgrade_room(roomless_token) self.assertEqual(403, channel.code, channel.result) + def test_left_room(self) -> None: + """ + A user who is no longer in the room cannot upgrade the room. + """ + # Remove the user from the room. + self.helper.leave(self.room_id, self.creator, tok=self.creator_token) + + channel = self._upgrade_room(self.creator_token) + self.assertEqual(403, channel.code, channel.result) + def test_power_levels(self) -> None: """ Another user can upgrade the room if their power level is increased. @@ -297,3 +326,47 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.assertEqual( create_event.content.get(EventContentFields.ROOM_TYPE), test_room_type ) + + def test_second_upgrade_from_same_user(self) -> None: + """A second room upgrade from the same user is deduplicated.""" + channel1 = self._upgrade_room() + self.assertEqual(200, channel1.code, channel1.result) + + channel2 = self._upgrade_room(expire_cache=False) + self.assertEqual(200, channel2.code, channel2.result) + + self.assertEqual( + channel1.json_body["replacement_room"], + channel2.json_body["replacement_room"], + ) + + def test_second_upgrade_after_delay(self) -> None: + """A second room upgrade is not deduplicated after some time has passed.""" + channel1 = self._upgrade_room() + self.assertEqual(200, channel1.code, channel1.result) + + channel2 = self._upgrade_room(expire_cache=True) + self.assertEqual(200, channel2.code, channel2.result) + + self.assertNotEqual( + channel1.json_body["replacement_room"], + channel2.json_body["replacement_room"], + ) + + def test_second_upgrade_from_different_user(self) -> None: + """A second room upgrade from a different user is blocked.""" + channel = self._upgrade_room() + self.assertEqual(200, channel.code, channel.result) + + channel = self._upgrade_room(self.other_token, expire_cache=False) + self.assertEqual(400, channel.code, channel.result) + + def test_first_upgrade_does_not_block_second(self) -> None: + """A second room upgrade is not blocked when a previous upgrade attempt was not + allowed. + """ + channel = self._upgrade_room(self.other_token) + self.assertEqual(403, channel.code, channel.result) + + channel = self._upgrade_room(expire_cache=False) + self.assertEqual(200, channel.code, channel.result) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index a0788b1bb0..dd26145bf8 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py
@@ -41,6 +41,7 @@ from twisted.web.resource import Resource from twisted.web.server import Site from synapse.api.constants import Membership +from synapse.api.errors import Codes from synapse.server import HomeServer from synapse.types import JsonDict @@ -135,11 +136,11 @@ class RestHelper: self.site, "POST", path, - json.dumps(content).encode("utf8"), + content, custom_headers=custom_headers, ) - assert channel.result["code"] == b"%d" % expect_code, channel.result + assert channel.code == expect_code, channel.result self.auth_user_id = temp_id if expect_code == HTTPStatus.OK: @@ -171,6 +172,8 @@ class RestHelper: expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, + expect_errcode: Optional[Codes] = None, + expect_additional_fields: Optional[dict] = None, ) -> None: self.change_membership( room=room, @@ -180,6 +183,8 @@ class RestHelper: appservice_user_id=appservice_user_id, membership=Membership.JOIN, expect_code=expect_code, + expect_errcode=expect_errcode, + expect_additional_fields=expect_additional_fields, ) def knock( @@ -205,14 +210,12 @@ class RestHelper: self.site, "POST", path, - json.dumps(data).encode("utf8"), + data, ) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -263,6 +266,7 @@ class RestHelper: appservice_user_id: Optional[str] = None, expect_code: int = HTTPStatus.OK, expect_errcode: Optional[str] = None, + expect_additional_fields: Optional[dict] = None, ) -> None: """ Send a membership state event into a room. @@ -303,14 +307,12 @@ class RestHelper: self.site, "PUT", path, - json.dumps(data).encode("utf8"), + data, ) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -323,6 +325,21 @@ class RestHelper: channel.result["body"], ) + if expect_additional_fields is not None: + for expect_key, expect_value in expect_additional_fields.items(): + assert expect_key in channel.json_body, "Expected field %s, got %s" % ( + expect_key, + channel.json_body, + ) + assert ( + channel.json_body[expect_key] == expect_value + ), "Expected: %s at %s, got: %s, resp: %s" % ( + expect_value, + expect_key, + channel.json_body[expect_key], + channel.json_body, + ) + self.auth_user_id = temp_id def send( @@ -371,15 +388,13 @@ class RestHelper: self.site, "PUT", path, - json.dumps(content or {}).encode("utf8"), + content or {}, custom_headers=custom_headers, ) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -428,11 +443,9 @@ class RestHelper: channel = make_request(self.hs.get_reactor(), self.site, method, path, content) - assert ( - int(channel.result["code"]) == expect_code - ), "Expected: %d, got: %d, resp: %r" % ( + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], ) @@ -524,7 +537,7 @@ class RestHelper: assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, - int(channel.result["code"]), + channel.code, channel.result["body"], )