diff options
Diffstat (limited to 'tests')
37 files changed, 1186 insertions, 816 deletions
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 18649c2c05..c86f783c5b 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -314,3 +314,77 @@ class TestRatelimiter(unittest.HomeserverTestCase): # Check that we get rate limited after using that token. self.assertFalse(consume_at(11.1)) + + def test_record_action_which_doesnt_fill_bucket(self) -> None: + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + # Observe two actions, leaving room in the bucket for one more. + limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0) + + # We should be able to take a new action now. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertTrue(success) + + # ... but not two. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + def test_record_action_which_fills_bucket(self) -> None: + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + # Observe three actions, filling up the bucket. + limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0) + + # We should be unable to take a new action now. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + # If we wait 10 seconds to leak a token, we should be able to take one action... + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertTrue(success) + + # ... but not two. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertFalse(success) + + def test_record_action_which_overfills_bucket(self) -> None: + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + # Observe four actions, exceeding the bucket. + limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0) + + # We should be prevented from taking a new action now. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + # If we wait 10 seconds to leak a token, we should be unable to take an action + # because the bucket is still full. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertFalse(success) + + # But after another 10 seconds we leak a second token, giving us room for + # action. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=20.0) + ) + self.assertTrue(success) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 9f1115dd23..c6dd99316a 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError @@ -50,7 +51,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) complexity = channel.json_body["v1"] self.assertTrue(complexity > 0, complexity) @@ -62,7 +63,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index 268a48d7ba..d2bda07198 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -45,7 +45,7 @@ class FederationClientTest(FederatingHomeserverTestCase): # mock up some events to use in the response. # In real life, these would have things in `prev_events` and `auth_events`, but that's # a bit annoying to mock up, and the code under test doesn't care, so we don't bother. - create_event_dict = self.add_hashes_and_signatures( + create_event_dict = self.add_hashes_and_signatures_from_other_server( { "room_id": test_room_id, "type": "m.room.create", @@ -57,7 +57,7 @@ class FederationClientTest(FederatingHomeserverTestCase): "origin_server_ts": 500, } ) - member_event_dict = self.add_hashes_and_signatures( + member_event_dict = self.add_hashes_and_signatures_from_other_server( { "room_id": test_room_id, "type": "m.room.member", @@ -69,7 +69,7 @@ class FederationClientTest(FederatingHomeserverTestCase): "origin_server_ts": 600, } ) - pl_event_dict = self.add_hashes_and_signatures( + pl_event_dict = self.add_hashes_and_signatures_from_other_server( { "room_id": test_room_id, "type": "m.room.power_levels", diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 413b3c9426..3a6ef221ae 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from parameterized import parameterized @@ -20,7 +21,6 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin @@ -59,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), query_content, ) - self.assertEqual(400, channel.code, channel.result) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") @@ -120,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,) ) - self.assertEqual(403, channel.code, channel.result) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -148,13 +148,13 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): tok2 = self.login("fozzie", "bear") self.helper.join(self._room_id, second_member_user_id, tok=tok2) - def _make_join(self, user_id) -> JsonDict: + def _make_join(self, user_id: str) -> JsonDict: channel = self.make_signed_federation_request( "GET", f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}" f"?ver={DEFAULT_ROOM_VERSION}", ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) return channel.json_body def test_send_join(self): @@ -163,18 +163,16 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): join_result = self._make_join(joining_user) join_event_dict = join_result["event"] - add_hashes_and_signatures( - KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + self.add_hashes_and_signatures_from_other_server( join_event_dict, - signature_name=self.OTHER_SERVER_NAME, - signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], ) channel = self.make_signed_federation_request( "PUT", f"/_matrix/federation/v2/send_join/{self._room_id}/x", content=join_event_dict, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # we should get complete room state back returned_state = [ @@ -220,18 +218,16 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): join_result = self._make_join(joining_user) join_event_dict = join_result["event"] - add_hashes_and_signatures( - KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + self.add_hashes_and_signatures_from_other_server( join_event_dict, - signature_name=self.OTHER_SERVER_NAME, - signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], ) channel = self.make_signed_federation_request( "PUT", f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", content=join_event_dict, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # expect a reduced room state returned_state = [ @@ -264,6 +260,67 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}}) + def test_make_join_respects_room_join_rate_limit(self) -> None: + # In the test setup, two users join the room. Since the rate limiter burst + # count is 3, a new make_join request to the room should be accepted. + + joining_user = "@ronniecorbett:" + self.OTHER_SERVER_NAME + self._make_join(joining_user) + + # Now have a new local user join the room. This saturates the rate limiter + # bucket, so the next make_join should be denied. + new_local_user = self.register_user("animal", "animal") + token = self.login("animal", "animal") + self.helper.join(self._room_id, new_local_user, tok=token) + + joining_user = "@ronniebarker:" + self.OTHER_SERVER_NAME + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/make_join/{self._room_id}/{joining_user}" + f"?ver={DEFAULT_ROOM_VERSION}", + ) + self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body) + + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}}) + def test_send_join_contributes_to_room_join_rate_limit_and_is_limited(self) -> None: + # Make two make_join requests up front. (These are rate limited, but do not + # contribute to the rate limit.) + join_event_dicts = [] + for i in range(2): + joining_user = f"@misspiggy{i}:{self.OTHER_SERVER_NAME}" + join_result = self._make_join(joining_user) + join_event_dict = join_result["event"] + self.add_hashes_and_signatures_from_other_server( + join_event_dict, + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + ) + join_event_dicts.append(join_event_dict) + + # In the test setup, two users join the room. Since the rate limiter burst + # count is 3, the first send_join should be accepted... + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self._room_id}/join0", + content=join_event_dicts[0], + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # ... but the second should be denied. + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self._room_id}/join1", + content=join_event_dicts[1], + ) + self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body) + + # NB: we could write a test which checks that the send_join event is seen + # by other workers over replication, and that they update their rate limit + # buckets accordingly. I'm going to assume that the join event gets sent over + # replication, at which point the tests.handlers.room_member test + # test_local_users_joining_on_another_worker_contribute_to_rate_limit + # is probably sufficient to reassure that the bucket is updated. + def _create_acl_event(content): return make_event_from_dict( diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index d21c11b716..0d048207b7 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict +from http import HTTPStatus from typing import Dict, List from synapse.api.constants import EventTypes, JoinRules, Membership @@ -255,7 +256,7 @@ class FederationKnockingTestCase( RoomVersions.V7.identifier, ), ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Note: We don't expect the knock membership event to be sent over federation as # part of the stripped room state, as the knocking homeserver already has that @@ -293,7 +294,7 @@ class FederationKnockingTestCase( % (room_id, signed_knock_event.event_id), signed_knock_event_json, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Check that we got the stripped room state in return room_state_events = channel.json_body["knock_state_events"] diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index d96d5aa138..b17af2725b 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -50,7 +50,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_scheduler = Mock() hs = Mock() hs.get_datastores.return_value = Mock(main=self.mock_store) - self.mock_store.get_received_ts.return_value = make_awaitable(0) + self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( None @@ -76,9 +76,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) - self.mock_store.get_new_events_for_appservice.side_effect = [ - make_awaitable((0, [])), - make_awaitable((1, [event])), + self.mock_store.get_all_new_events_stream.side_effect = [ + make_awaitable((0, [], {})), + make_awaitable((1, [event], {event.event_id: 0})), ] self.handler.notify_interested_services(RoomStreamToken(None, 1)) @@ -95,8 +95,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_new_events_for_appservice.side_effect = [ - make_awaitable((0, [event])), + self.mock_store.get_all_new_events_stream.side_effect = [ + make_awaitable((0, [event], {event.event_id: 0})), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -112,8 +112,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_new_events_for_appservice.side_effect = [ - make_awaitable((0, [event])), + self.mock_store.get_all_new_events_stream.side_effect = [ + make_awaitable((0, [event], {event.event_id: 0})), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 9b9c11fab7..8a0bb91f40 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, cast +from typing import cast from unittest import TestCase from twisted.test.proto_helpers import MemoryReactor @@ -50,8 +50,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main - self.state_storage_controller = hs.get_storage_controllers().state - self._event_auth_handler = hs.get_event_auth_handler() return hs def test_exchange_revoked_invite(self) -> None: @@ -256,7 +254,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ] for _ in range(0, 8): event = make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "origin_server_ts": 1, "type": "m.room.message", @@ -314,142 +312,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) self.get_success(d) - def test_backfill_floating_outlier_membership_auth(self) -> None: - """ - As the local homeserver, check that we can properly process a federated - event from the OTHER_SERVER with auth_events that include a floating - membership event from the OTHER_SERVER. - - Regression test, see #10439. - """ - OTHER_SERVER = "otherserver" - OTHER_USER = "@otheruser:" + OTHER_SERVER - - # create the room - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") - room_id = self.helper.create_room_as( - room_creator=user_id, - is_public=True, - tok=tok, - extra_content={ - "preset": "public_chat", - }, - ) - room_version = self.get_success(self.store.get_room_version(room_id)) - - prev_event_ids = self.get_success(self.store.get_prev_events_for_room(room_id)) - ( - most_recent_prev_event_id, - most_recent_prev_event_depth, - ) = self.get_success(self.store.get_max_depth_of(prev_event_ids)) - # mapping from (type, state_key) -> state_event_id - assert most_recent_prev_event_id is not None - prev_state_map = self.get_success( - self.state_storage_controller.get_state_ids_for_event( - most_recent_prev_event_id - ) - ) - # List of state event ID's - prev_state_ids = list(prev_state_map.values()) - auth_event_ids = prev_state_ids - auth_events = list( - self.get_success(self.store.get_events(auth_event_ids)).values() - ) - - # build a floating outlier member state event - fake_prev_event_id = "$" + random_string(43) - member_event_dict = { - "type": EventTypes.Member, - "content": { - "membership": "join", - }, - "state_key": OTHER_USER, - "room_id": room_id, - "sender": OTHER_USER, - "depth": most_recent_prev_event_depth, - "prev_events": [fake_prev_event_id], - "origin_server_ts": self.clock.time_msec(), - "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}}, - } - builder = self.hs.get_event_builder_factory().for_room_version( - room_version, member_event_dict - ) - member_event = self.get_success( - builder.build( - prev_event_ids=member_event_dict["prev_events"], - auth_event_ids=self._event_auth_handler.compute_auth_events( - builder, - prev_state_map, - for_verification=False, - ), - depth=member_event_dict["depth"], - ) - ) - # Override the signature added from "test" homeserver that we created the event with - member_event.signatures = member_event_dict["signatures"] - - # Add the new member_event to the StateMap - updated_state_map = dict(prev_state_map) - updated_state_map[ - (member_event.type, member_event.state_key) - ] = member_event.event_id - auth_events.append(member_event) - - # build and send an event authed based on the member event - message_event_dict = { - "type": EventTypes.Message, - "content": {}, - "room_id": room_id, - "sender": OTHER_USER, - "depth": most_recent_prev_event_depth, - "prev_events": prev_event_ids.copy(), - "origin_server_ts": self.clock.time_msec(), - "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}}, - } - builder = self.hs.get_event_builder_factory().for_room_version( - room_version, message_event_dict - ) - message_event = self.get_success( - builder.build( - prev_event_ids=message_event_dict["prev_events"], - auth_event_ids=self._event_auth_handler.compute_auth_events( - builder, - updated_state_map, - for_verification=False, - ), - depth=message_event_dict["depth"], - ) - ) - # Override the signature added from "test" homeserver that we created the event with - message_event.signatures = message_event_dict["signatures"] - - # Stub the /event_auth response from the OTHER_SERVER - async def get_event_auth( - destination: str, room_id: str, event_id: str - ) -> List[EventBase]: - return [ - event_from_pdu_json(ae.get_pdu_json(), room_version=room_version) - for ae in auth_events - ] - - self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment] - - with LoggingContext("receive_pdu"): - # Fake the OTHER_SERVER federating the message event over to our local homeserver - d = run_in_background( - self.hs.get_federation_event_handler().on_receive_pdu, - OTHER_SERVER, - message_event, - ) - self.get_success(d) - - # Now try and get the events on our local homeserver - stored_event = self.get_success( - self.store.get_event(message_event.event_id, allow_none=True) - ) - self.assertTrue(stored_event is not None) - @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 4b1a8f04db..51c8dd6498 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -104,7 +104,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # mock up a load of state events which we are missing state_events = [ make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "type": "test_state_type", "state_key": f"state_{i}", @@ -131,7 +131,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # Depending on the test, we either persist this upfront (as an outlier), # or let the server request it. prev_event = make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "type": "test_regular_type", "room_id": room_id, @@ -165,7 +165,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # mock up a regular event to pass into _process_pulled_event pulled_event = make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "type": "test_regular_type", "room_id": room_id, diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 82b3bb3b73..4c62449c89 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -14,6 +14,7 @@ """Tests for the password_auth_provider interface""" +from http import HTTPStatus from typing import Any, Type, Union from unittest.mock import Mock @@ -188,14 +189,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # check_password must return an awaitable mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._send_password_login("u", "p") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@u:test", channel.json_body["user_id"]) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") mock_password_provider.reset_mock() # login with mxid should work too channel = self._send_password_login("@u:bz", "p") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@u:bz", channel.json_body["user_id"]) mock_password_provider.check_password.assert_called_once_with("@u:bz", "p") mock_password_provider.reset_mock() @@ -204,7 +205,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # in these cases, but at least we can guard against the API changing # unexpectedly channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"]) mock_password_provider.check_password.assert_called_once_with( "@ USER🙂NAME :test", " pASS😢word " @@ -258,10 +259,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # check_password must return an awaitable mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("u", "p") - self.assertEqual(channel.code, 403, channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@localuser:test", channel.json_body["user_id"]) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) @@ -382,7 +383,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("u", "p") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_password.assert_not_called() @override_config(legacy_providers_config(LegacyCustomAuthProvider)) @@ -406,14 +407,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login with missing param should be rejected channel = self._send_login("test.login_type", "u") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", None) ) channel = self._send_login("test.login_type", "u", test_field="y") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) mock_password_provider.check_auth.assert_called_once_with( "u", "test.login_type", {"test_field": "y"} @@ -427,7 +428,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"]) mock_password_provider.check_auth.assert_called_once_with( " USER🙂NAME ", "test.login_type", {"test_field": " abc "} @@ -510,7 +511,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ("@user:bz", callback) ) channel = self._send_login("test.login_type", "u", test_field="y") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) mock_password_provider.check_auth.assert_called_once_with( "u", "test.login_type", {"test_field": "y"} @@ -549,7 +550,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() @override_config( @@ -584,7 +585,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() @override_config( @@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_password.assert_not_called() @@ -646,13 +647,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) tok1 = channel.json_body["access_token"] channel = self._send_login( "test.login_type", "localuser", test_field="", device_id="dev2" ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # make the initial request which returns a 401 channel = self._delete_device(tok1, "dev2") @@ -721,7 +722,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # password login shouldn't work and should be rejected with a 400 # ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) def test_on_logged_out(self): """Tests that the on_logged_out callback is called when the user logs out.""" @@ -884,7 +885,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): }, access_token=tok, ) - self.assertEqual(channel.code, 403, channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.THREEPID_DENIED, @@ -906,7 +907,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): }, access_token=tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertIn("sid", channel.json_body) m.assert_called_once_with("email", "bar@test.com", registration) @@ -949,12 +950,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "register", {"auth": {"session": session, "type": LoginType.DUMMY}}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) return channel.json_body def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) return channel.json_body["flows"] def _send_password_login(self, user: str, password: str) -> FakeChannel: diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py new file mode 100644 index 0000000000..254e7e4b80 --- /dev/null +++ b/tests/handlers/test_room_member.py @@ -0,0 +1,290 @@ +from http import HTTPStatus +from unittest.mock import Mock, patch + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +import synapse.rest.client.login +import synapse.rest.client.room +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import LimitExceededError +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events import FrozenEventV3 +from synapse.federation.federation_client import SendJoinResult +from synapse.server import HomeServer +from synapse.types import UserID, create_requester +from synapse.util import Clock + +from tests.replication._base import RedisMultiWorkerStreamTestCase +from tests.server import make_request +from tests.test_utils import make_awaitable +from tests.unittest import FederatingHomeserverTestCase, override_config + + +class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.client.login.register_servlets, + synapse.rest.client.room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.handler = hs.get_room_member_handler() + + # Create three users. + self.alice = self.register_user("alice", "pass") + self.alice_token = self.login("alice", "pass") + self.bob = self.register_user("bob", "pass") + self.bob_token = self.login("bob", "pass") + self.chris = self.register_user("chris", "pass") + self.chris_token = self.login("chris", "pass") + + # Create a room on this homeserver. Note that this counts as a join: it + # contributes to the rate limter's count of actions + self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token) + + self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}" + + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) + def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None: + # The rate limiter has accumulated one token from Alice's join after the create + # event. + # Try joining the room as Bob. + self.get_success( + self.handler.update_membership( + requester=create_requester(self.bob), + target=UserID.from_string(self.bob), + room_id=self.room_id, + action=Membership.JOIN, + ) + ) + + # The rate limiter bucket is full. A second join should be denied. + self.get_failure( + self.handler.update_membership( + requester=create_requester(self.chris), + target=UserID.from_string(self.chris), + room_id=self.room_id, + action=Membership.JOIN, + ), + LimitExceededError, + ) + + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) + def test_local_user_profile_edits_dont_contribute_to_limit(self) -> None: + # The rate limiter has accumulated one token from Alice's join after the create + # event. Alice should still be able to change her displayname. + self.get_success( + self.handler.update_membership( + requester=create_requester(self.alice), + target=UserID.from_string(self.alice), + room_id=self.room_id, + action=Membership.JOIN, + content={"displayname": "Alice Cooper"}, + ) + ) + + # Still room in the limiter bucket. Chris's join should be accepted. + self.get_success( + self.handler.update_membership( + requester=create_requester(self.chris), + target=UserID.from_string(self.chris), + room_id=self.room_id, + action=Membership.JOIN, + ) + ) + + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}}) + def test_remote_joins_contribute_to_rate_limit(self) -> None: + # Join once, to fill the rate limiter bucket. + # + # To do this we have to mock the responses from the remote homeserver. + # We also patch out a bunch of event checks on our end. All we're really + # trying to check here is that remote joins will bump the rate limter when + # they are persisted. + create_event_source = { + "auth_events": [], + "content": { + "creator": f"@creator:{self.OTHER_SERVER_NAME}", + "room_version": self.hs.config.server.default_room_version.identifier, + }, + "depth": 0, + "origin_server_ts": 0, + "prev_events": [], + "room_id": self.intially_unjoined_room_id, + "sender": f"@creator:{self.OTHER_SERVER_NAME}", + "state_key": "", + "type": EventTypes.Create, + } + self.add_hashes_and_signatures_from_other_server( + create_event_source, + self.hs.config.server.default_room_version, + ) + create_event = FrozenEventV3( + create_event_source, + self.hs.config.server.default_room_version, + {}, + None, + ) + + join_event_source = { + "auth_events": [create_event.event_id], + "content": {"membership": "join"}, + "depth": 1, + "origin_server_ts": 100, + "prev_events": [create_event.event_id], + "sender": self.bob, + "state_key": self.bob, + "room_id": self.intially_unjoined_room_id, + "type": EventTypes.Member, + } + add_hashes_and_signatures( + self.hs.config.server.default_room_version, + join_event_source, + self.hs.hostname, + self.hs.signing_key, + ) + join_event = FrozenEventV3( + join_event_source, + self.hs.config.server.default_room_version, + {}, + None, + ) + + mock_make_membership_event = Mock( + return_value=make_awaitable( + ( + self.OTHER_SERVER_NAME, + join_event, + self.hs.config.server.default_room_version, + ) + ) + ) + mock_send_join = Mock( + return_value=make_awaitable( + SendJoinResult( + join_event, + self.OTHER_SERVER_NAME, + state=[create_event], + auth_chain=[create_event], + partial_state=False, + servers_in_room=[], + ) + ) + ) + + with patch.object( + self.handler.federation_handler.federation_client, + "make_membership_event", + mock_make_membership_event, + ), patch.object( + self.handler.federation_handler.federation_client, + "send_join", + mock_send_join, + ), patch( + "synapse.event_auth._is_membership_change_allowed", + return_value=None, + ), patch( + "synapse.handlers.federation_event.check_state_dependent_auth_rules", + return_value=None, + ): + self.get_success( + self.handler.update_membership( + requester=create_requester(self.bob), + target=UserID.from_string(self.bob), + room_id=self.intially_unjoined_room_id, + action=Membership.JOIN, + remote_room_hosts=[self.OTHER_SERVER_NAME], + ) + ) + + # Try to join as Chris. Should get denied. + self.get_failure( + self.handler.update_membership( + requester=create_requester(self.chris), + target=UserID.from_string(self.chris), + room_id=self.intially_unjoined_room_id, + action=Membership.JOIN, + remote_room_hosts=[self.OTHER_SERVER_NAME], + ), + LimitExceededError, + ) + + # TODO: test that remote joins to a room are rate limited. + # Could do this by setting the burst count to 1, then: + # - remote-joining a room + # - immediately leaving + # - trying to remote-join again. + + +class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.client.login.register_servlets, + synapse.rest.client.room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.handler = hs.get_room_member_handler() + + # Create three users. + self.alice = self.register_user("alice", "pass") + self.alice_token = self.login("alice", "pass") + self.bob = self.register_user("bob", "pass") + self.bob_token = self.login("bob", "pass") + self.chris = self.register_user("chris", "pass") + self.chris_token = self.login("chris", "pass") + + # Create a room on this homeserver. + # Note that this counts as a + self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token) + self.intially_unjoined_room_id = "!example:otherhs" + + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) + def test_local_users_joining_on_another_worker_contribute_to_rate_limit( + self, + ) -> None: + # The rate limiter has accumulated one token from Alice's join after the create + # event. + self.replicate() + + # Spawn another worker and have bob join via it. + worker_app = self.make_worker_hs( + "synapse.app.generic_worker", extra_config={"worker_name": "other worker"} + ) + worker_site = self._hs_to_site[worker_app] + channel = make_request( + self.reactor, + worker_site, + "POST", + f"/_matrix/client/v3/rooms/{self.room_id}/join", + access_token=self.bob_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # wait for join to arrive over replication + self.replicate() + + # Try to join as Chris on the worker. Should get denied because Alice + # and Bob have both joined the room. + self.get_failure( + worker_app.get_room_member_handler().update_membership( + requester=create_requester(self.chris), + target=UserID.from_string(self.chris), + room_id=self.room_id, + action=Membership.JOIN, + ), + LimitExceededError, + ) + + # Try to join as Chris on the original worker. Should get denied because Alice + # and Bob have both joined the room. + self.get_failure( + self.handler.update_membership( + requester=create_requester(self.chris), + target=UserID.from_string(self.chris), + room_id=self.room_id, + action=Membership.JOIN, + ), + LimitExceededError, + ) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index ecc7cc6461..e3f38fbcc5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -159,7 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Blow away caches (supported room versions can only change due to a restart). self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) self.store._event_ref.clear() # The rooms should be excluded from the sync response. diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 230dc76f72..2526136ff8 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -21,7 +21,7 @@ from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RoomTypes from synapse.api.errors import Codes from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room @@ -1130,6 +1130,8 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("guest_access", r) self.assertIn("history_visibility", r) self.assertIn("state_events", r) + self.assertIn("room_type", r) + self.assertIsNone(r["room_type"]) # Check that the correct number of total rooms was returned self.assertEqual(channel.json_body["total_rooms"], total_rooms) @@ -1229,7 +1231,11 @@ class RoomTestCase(unittest.HomeserverTestCase): def test_correct_room_attributes(self) -> None: """Test the correct attributes for a room are returned""" # Create a test room - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id = self.helper.create_room_as( + self.admin_user, + tok=self.admin_user_tok, + extra_content={"creation_content": {"type": RoomTypes.SPACE}}, + ) test_alias = "#test:test" test_room_name = "something" @@ -1306,6 +1312,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id, r["room_id"]) self.assertEqual(test_room_name, r["name"]) self.assertEqual(test_alias, r["canonical_alias"]) + self.assertEqual(RoomTypes.SPACE, r["room_type"]) def test_room_list_sort_order(self) -> None: """Test room list sort ordering. alphabetical name versus number of members, @@ -1630,7 +1637,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("guest_access", channel.json_body) self.assertIn("history_visibility", channel.json_body) self.assertIn("state_events", channel.json_body) - + self.assertIn("room_type", channel.json_body) self.assertEqual(room_id_1, channel.json_body["room_id"]) def test_single_room_devices(self) -> None: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e32aaadb98..12db68d564 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1379,7 +1379,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1434,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1512,7 +1512,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "admin": False}, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1550,7 +1550,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # Admin user is not blocked by mau anymore - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1585,7 +1585,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1626,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1636,6 +1636,41 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(len(pushers), 0) + @override_config( + { + "email": { + "enable_notifs": True, + "notif_for_new_users": True, + "notif_from": "test@example.com", + }, + "public_baseurl": "https://example.com", + } + ) + def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> None: + """ + Check that a new regular user is created successfully when they have a msisdn + threepid and email notif_for_new_users is set to True. + """ + url = self.url_prefix % "@bob:test" + + # Create user + body = { + "password": "abc123", + "threepids": [{"medium": "msisdn", "address": "1234567890"}], + } + + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body, + ) + + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"]) + def test_set_password(self) -> None: """ Test setting a new password for another user. @@ -2372,7 +2407,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123"}, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 1f9b65351e..7ae926dc9c 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os import re from email.parser import Parser +from http import HTTPStatus from typing import Any, Dict, List, Optional, Union from unittest.mock import Mock @@ -95,10 +95,8 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): """ body = {"type": "m.login.password", "user": username, "password": password} - channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") - ) - self.assertEqual(channel.code, 403, channel.result) + channel = self.make_request("POST", "/_matrix/client/r0/login", body) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) def test_basic_password_reset(self) -> None: """Test basic password reset flow""" @@ -347,7 +345,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button @@ -362,7 +360,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, content_is_form=True, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" @@ -390,7 +388,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): new_password: str, session_id: str, client_secret: str, - expected_code: int = 200, + expected_code: int = HTTPStatus.OK, ) -> None: channel = self.make_request( "POST", @@ -479,16 +477,14 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertEqual(memberships[0].room_id, room_id, memberships) def deactivate(self, user_id: str, tok: str) -> None: - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "test", - }, - "erase": False, - } - ) + request_data = { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + } channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) @@ -715,7 +711,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -725,7 +723,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email(self) -> None: @@ -747,7 +745,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -756,7 +754,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email_if_disabled(self) -> None: @@ -781,7 +779,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -791,7 +791,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) @@ -817,7 +817,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -827,7 +829,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_no_valid_token(self) -> None: @@ -852,7 +854,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -862,7 +866,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @override_config({"next_link_domain_whitelist": None}) @@ -872,7 +876,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="https://example.com/a/good/site", - expect_code=200, + expect_code=HTTPStatus.OK, ) @override_config({"next_link_domain_whitelist": None}) @@ -884,7 +888,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="some-protocol://abcdefghijklmopqrstuvwxyz", - expect_code=200, + expect_code=HTTPStatus.OK, ) @override_config({"next_link_domain_whitelist": None}) @@ -895,7 +899,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="file:///host/path", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) @@ -907,28 +911,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link=None, - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://example.com/some/good/page", - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://example.org/some/also/good/page", - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://bad.example.org/some/bad/page", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) @override_config({"next_link_domain_whitelist": []}) @@ -940,7 +944,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="https://example.com/a/page", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) def _request_token( @@ -948,7 +952,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): email: str, client_secret: str, next_link: Optional[str] = None, - expect_code: int = 200, + expect_code: int = HTTPStatus.OK, ) -> Optional[str]: """Request a validation token to add an email address to a user's account @@ -993,7 +997,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(expected_errcode, channel.json_body["errcode"]) self.assertEqual(expected_error, channel.json_body["error"]) @@ -1002,7 +1008,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" @@ -1052,7 +1058,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -1061,7 +1067,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} @@ -1092,7 +1098,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): """Tests that not providing any MXID raises an error.""" self._test_status( users=None, - expected_status_code=400, + expected_status_code=HTTPStatus.BAD_REQUEST, expected_errcode=Codes.MISSING_PARAM, ) @@ -1100,7 +1106,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): """Tests that providing an invalid MXID raises an error.""" self._test_status( users=["bad:test"], - expected_status_code=400, + expected_status_code=HTTPStatus.BAD_REQUEST, expected_errcode=Codes.INVALID_PARAM, ) @@ -1286,7 +1292,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): def _test_status( self, users: Optional[List[str]], - expected_status_code: int = 200, + expected_status_code: int = HTTPStatus.OK, expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, expected_failures: Optional[List[str]] = None, expected_errcode: Optional[str] = None, diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py index 16e7ef41bc..7a88aa2cda 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor @@ -97,8 +96,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # We use deliberately a localpart under the length threshold so # that we can make sure that the check is done on the whole alias. - data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} - request_data = json.dumps(data) + request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) @@ -110,8 +108,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # Check with an alias of allowed length. There should already be # a test that ensures it works in test_register.py, but let's be # as cautious as possible here. - data = {"room_alias_name": random_string(5)} - request_data = json.dumps(data) + request_data = {"room_alias_name": random_string(5)} channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) @@ -144,8 +141,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # Add an alias for the room, as the appservice alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string() - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", @@ -193,8 +189,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.hs.hostname, ) - data = {"aliases": [self.random_alias(alias_length)]} - request_data = json.dumps(data) + request_data = {"aliases": [self.random_alias(alias_length)]} channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok @@ -206,8 +201,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) -> str: alias = self.random_alias(alias_length) url = "/_matrix/client/r0/directory/room/%s" % alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 299b9d21e2..dc17c9d113 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor @@ -51,12 +50,11 @@ class IdentityTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK, channel.result) room_id = channel.json_body["room_id"] - params = { + request_data = { "id_server": "testis", "medium": "email", "address": "test@example.com", } - request_data = json.dumps(params) request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") channel = self.make_request( b"POST", request_url, request_data, access_token=tok diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index f6efa5fe37..a2958f6959 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import time import urllib.parse +from http import HTTPStatus from typing import Any, Dict, List, Optional from unittest.mock import Mock from urllib.parse import urlencode @@ -261,20 +261,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) access_token = channel.json_body["access_token"] device_id = channel.json_body["device_id"] # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) @@ -288,7 +288,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # more requests with the expired token should still return a soft-logout self.reactor.advance(3600) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) @@ -296,7 +296,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self._delete_device(access_token_2, "kermit", "monkey", device_id) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], False) @@ -307,7 +307,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token ) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) # check it's a UI-Auth fail self.assertEqual( set(channel.json_body.keys()), @@ -330,7 +330,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): access_token=access_token, content={"auth": auth}, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: @@ -341,14 +341,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) @@ -367,14 +367,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) @@ -399,7 +399,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "/_matrix/client/v3/login", - json.dumps(body).encode("utf8"), + body, custom_headers=None, ) @@ -466,7 +466,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_get_login_flows(self) -> None: """GET /login should return password and SSO flows""" channel = self.make_request("GET", "/_matrix/client/r0/login") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) expected_flow_types = [ "m.login.cas", @@ -494,14 +494,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker channel = self._make_sso_redirect_request(None) - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers uri = location_headers[0] # hitting that picker should give us some HTML channel = self.make_request("GET", uri) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # parse the form to check it has fields assumed elsewhere in this class html = channel.result["body"].decode("utf-8") @@ -530,7 +530,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): + "&idp=cas", shorthand=False, ) - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers cas_uri = location_headers[0] @@ -555,7 +555,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=saml", ) - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers saml_uri = location_headers[0] @@ -579,7 +579,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=oidc", ) - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers oidc_uri = location_headers[0] @@ -606,7 +606,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) content_type_headers = channel.headers.getRawHeaders("Content-Type") assert content_type_headers self.assertTrue(content_type_headers[-1].startswith("text/html")) @@ -634,7 +634,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): "/login", content={"type": "m.login.token", "token": login_token}, ) - self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.code, HTTPStatus.OK, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") def test_multi_sso_redirect_to_unknown(self) -> None: @@ -643,18 +643,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) def test_client_idp_redirect_to_unknown(self) -> None: """If the client tries to pick an unknown IdP, return a 404""" channel = self._make_sso_redirect_request("xxx") - self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" channel = self._make_sso_redirect_request("oidc") - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers oidc_uri = location_headers[0] @@ -765,7 +765,7 @@ class CASTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", cas_ticket_url) # Test that the response is HTML. - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) content_type_header_value = "" for header in channel.result.get("headers", []): if header[0] == b"Content-Type": @@ -1246,7 +1246,7 @@ class UsernamePickerTestCase(HomeserverTestCase): ) # that should redirect to the username picker - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers picker_url = location_headers[0] @@ -1290,7 +1290,7 @@ class UsernamePickerTestCase(HomeserverTestCase): ("Content-Length", str(len(content))), ], ) - self.assertEqual(chan.code, 302, chan.result) + self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result) location_headers = chan.headers.getRawHeaders("Location") assert location_headers @@ -1300,7 +1300,7 @@ class UsernamePickerTestCase(HomeserverTestCase): path=location_headers[0], custom_headers=[("Cookie", "username_mapping_session=" + session_id)], ) - self.assertEqual(chan.code, 302, chan.result) + self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result) location_headers = chan.headers.getRawHeaders("Location") assert location_headers @@ -1325,5 +1325,5 @@ class UsernamePickerTestCase(HomeserverTestCase): "/login", content={"type": "m.login.token", "token": login_token}, ) - self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.code, HTTPStatus.OK, chan.result) self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py index 3a74d2e96c..e19d21d6ee 100644 --- a/tests/rest/client/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor @@ -89,7 +88,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_too_short(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "shorty"}) + request_data = {"username": "kermit", "password": "shorty"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -100,7 +99,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_digit(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + request_data = {"username": "kermit", "password": "longerpassword"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -111,7 +110,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_symbol(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + request_data = {"username": "kermit", "password": "l0ngerpassword"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -122,7 +121,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_uppercase(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + request_data = {"username": "kermit", "password": "l0ngerpassword!"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -133,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_lowercase(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + request_data = {"username": "kermit", "password": "L0NGERPASSWORD!"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -144,7 +143,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_compliant(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + request_data = {"username": "kermit", "password": "L0ngerpassword!"} channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has @@ -161,16 +160,14 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", compliant_password) tok = self.login("kermit", compliant_password) - request_data = json.dumps( - { - "new_password": not_compliant_password, - "auth": { - "password": compliant_password, - "type": LoginType.PASSWORD, - "user": user_id, - }, - } - ) + request_data = { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } channel = self.make_request( "POST", "/_matrix/client/r0/account/password", diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index afb08b2736..071b488cc0 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import datetime -import json import os from typing import Any, Dict, List, Tuple @@ -62,9 +61,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) self.hs.get_datastores().main.services_cache.append(appservice) - request_data = json.dumps( - {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) + request_data = { + "username": "as_user_kermit", + "type": APP_SERVICE_REGISTRATION_TYPE, + } channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data @@ -85,7 +85,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) self.hs.get_datastores().main.services_cache.append(appservice) - request_data = json.dumps({"username": "as_user_kermit"}) + request_data = {"username": "as_user_kermit"} channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data @@ -95,9 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_appservice_registration_invalid(self) -> None: self.appservice = None # no application service exists - request_data = json.dumps( - {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) + request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) @@ -105,14 +103,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"401", channel.result) def test_POST_bad_password(self) -> None: - request_data = json.dumps({"username": "kermit", "password": 666}) + request_data = {"username": "kermit", "password": 666} channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self) -> None: - request_data = json.dumps({"username": 777, "password": "monkey"}) + request_data = {"username": 777, "password": "monkey"} channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.result["code"], b"400", channel.result) @@ -121,13 +119,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_user_valid(self) -> None: user_id = "@kermit:test" device_id = "frogfone" - params = { + request_data = { "username": "kermit", "password": "monkey", "device_id": device_id, "auth": {"type": LoginType.DUMMY}, } - request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) det_data = { @@ -140,7 +137,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @override_config({"enable_registration": False}) def test_POST_disabled_registration(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request_data = {"username": "kermit", "password": "monkey"} self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) channel = self.make_request(b"POST", self.url, request_data) @@ -188,13 +185,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self) -> None: for i in range(0, 6): - params = { + request_data = { "username": "kermit" + str(i), "password": "monkey", "device_id": "frogfone", "auth": {"type": LoginType.DUMMY}, } - request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) if i == 5: @@ -234,7 +230,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } # Request without auth to get flows and session - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # Synapse adds a dummy stage to differentiate flows where otherwise one @@ -251,8 +247,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -262,8 +257,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "type": LoginType.DUMMY, "session": session, } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, params) det_data = { "user_id": f"@{username}:{self.hs.hostname}", "home_server": self.hs.hostname, @@ -290,7 +284,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "password": "monkey", } # Request without auth to get session - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Test with token param missing (invalid) @@ -298,21 +292,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "type": LoginType.REGISTRATION_TOKEN, "session": session, } - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with non-string (invalid) params["auth"]["token"] = 1234 - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with unknown token (invalid) params["auth"]["token"] = "1234" - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -337,9 +331,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): params1: JsonDict = {"username": "bert", "password": "monkey"} params2: JsonDict = {"username": "ernie", "password": "monkey"} # Do 2 requests without auth to get two session IDs - channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + channel1 = self.make_request(b"POST", self.url, params1) session1 = channel1.json_body["session"] - channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + channel2 = self.make_request(b"POST", self.url, params2) session2 = channel2.json_body["session"] # Use token with session1 and check `pending` is 1 @@ -348,9 +342,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session1, } - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Repeat request to make sure pending isn't increased again - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) pending = self.get_success( store.db_pool.simple_select_one_onecol( "registration_tokens", @@ -366,14 +360,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session2, } - channel = self.make_request(b"POST", self.url, json.dumps(params2)) + channel = self.make_request(b"POST", self.url, params2) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Check pending=0 and completed=1 res = self.get_success( store.db_pool.simple_select_one( @@ -386,7 +380,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(res["completed"], 1) # Check auth still fails when using token with session2 - channel = self.make_request(b"POST", self.url, json.dumps(params2)) + channel = self.make_request(b"POST", self.url, params2) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -411,7 +405,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) params: JsonDict = {"username": "kermit", "password": "monkey"} # Request without auth to get session - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Check authentication fails with expired token @@ -420,7 +414,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -435,7 +429,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) # Check authentication succeeds - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -460,9 +454,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Do 2 requests without auth to get two session IDs params1: JsonDict = {"username": "bert", "password": "monkey"} params2: JsonDict = {"username": "ernie", "password": "monkey"} - channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + channel1 = self.make_request(b"POST", self.url, params1) session1 = channel1.json_body["session"] - channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + channel2 = self.make_request(b"POST", self.url, params2) session2 = channel2.json_body["session"] # Use token with both sessions @@ -471,18 +465,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session1, } - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) params2["auth"] = { "type": LoginType.REGISTRATION_TOKEN, "token": token, "session": session2, } - self.make_request(b"POST", self.url, json.dumps(params2)) + self.make_request(b"POST", self.url, params2) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Check `result` of registration token stage for session1 is `True` result1 = self.get_success( @@ -550,7 +544,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Do request without auth to get a session ID params: JsonDict = {"username": "kermit", "password": "monkey"} - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Use token @@ -559,7 +553,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - self.make_request(b"POST", self.url, json.dumps(params)) + self.make_request(b"POST", self.url, params) # Delete token self.get_success( @@ -592,9 +586,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "require_at_registration": True, }, "account_threepid_delegates": { - "email": "https://id_server", "msisdn": "https://id_server", }, + "email": {"notif_from": "Synapse <synapse@example.com>"}, } ) def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: @@ -827,8 +821,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = {"user_id": user_id} - request_data = json.dumps(params) + request_data = {"user_id": user_id} channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEqual(channel.result["code"], b"200", channel.result) @@ -845,12 +838,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = { + request_data = { "user_id": user_id, "expiration_ts": 0, "enable_renewal_emails": False, } - request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEqual(channel.result["code"], b"200", channel.result) @@ -870,12 +862,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = { + request_data = { "user_id": user_id, "expiration_ts": 0, "enable_renewal_emails": False, } - request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEqual(channel.result["code"], b"200", channel.result) @@ -1041,16 +1032,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): (user_id, tok) = self.create_user() - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "monkey", - }, - "erase": False, - } - ) + request_data = { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + } channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py index 20a259fc43..ad0d0209f7 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json - from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -77,10 +75,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase): def _assert_status(self, response_status: int, data: JsonDict) -> None: channel = self.make_request( - "POST", - self.report_path, - json.dumps(data), - access_token=self.other_user_tok, + "POST", self.report_path, data, access_token=self.other_user_tok ) self.assertEqual( response_status, int(channel.result["code"]), msg=channel.result["body"] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index df7ffbe545..c45cb32090 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,6 +18,7 @@ """Tests REST events for /rooms paths.""" import json +from http import HTTPStatus from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from unittest.mock import Mock, call from urllib import parse as urlparse @@ -104,7 +105,7 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # set topic for public room channel = self.make_request( @@ -112,7 +113,7 @@ class RoomPermissionsTestCase(RoomBase): ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -134,28 +135,28 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_topic_perms(self) -> None: topic_content = b'{"topic":"My Topic Name"}' @@ -165,28 +166,28 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 channel = self.make_request("GET", topic_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) @@ -194,25 +195,25 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id channel = self.make_request("GET", topic_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 channel = self.make_request( @@ -220,7 +221,7 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def _test_get_membership( self, room: str, members: Iterable = frozenset(), expect_code: int = 200 @@ -309,14 +310,14 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=self.rmcreator_id, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) self.helper.change_membership( room=room, src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) def test_joined_permissions(self) -> None: @@ -342,7 +343,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # set left of other, expect 403 @@ -351,7 +352,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # set left of self, expect 200 @@ -371,7 +372,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=usr, membership=Membership.INVITE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) self.helper.change_membership( @@ -379,7 +380,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=usr, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # It is always valid to LEAVE if you've already left (currently.) @@ -388,7 +389,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember @@ -405,7 +406,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.BAN, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.FORBIDDEN, ) @@ -415,7 +416,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.BAN, - expect_code=200, + expect_code=HTTPStatus.OK, ) # from ban to invite: Must never happen. @@ -424,7 +425,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.INVITE, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -434,7 +435,7 @@ class RoomPermissionsTestCase(RoomBase): src=other, targ=other, membership=Membership.JOIN, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -444,7 +445,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.BAN, - expect_code=200, + expect_code=HTTPStatus.OK, ) # from ban to knock: Must never happen. @@ -453,7 +454,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.KNOCK, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -463,7 +464,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.LEAVE, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.FORBIDDEN, ) @@ -473,7 +474,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.LEAVE, - expect_code=200, + expect_code=HTTPStatus.OK, ) @@ -493,7 +494,7 @@ class RoomStateTestCase(RoomBase): "/rooms/%s/state" % room_id, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertCountEqual( [state_event["type"] for state_event in channel.json_body], { @@ -516,7 +517,7 @@ class RoomStateTestCase(RoomBase): "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id), ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(channel.json_body, {"membership": "join"}) @@ -530,16 +531,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self) -> None: channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self) -> None: room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_with_at_token(self) -> None: """ @@ -550,7 +551,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that permission is denied for @sid1:red to get the @@ -559,7 +560,7 @@ class RoomsMemberListTestCase(RoomBase): "GET", f"/rooms/{room_id}/members?at={sync_token}", ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member(self) -> None: """ @@ -572,14 +573,14 @@ class RoomsMemberListTestCase(RoomBase): # check that the user can see the member list to start with channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") # check the user can no longer see the member list channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member_with_at_token(self) -> None: """ @@ -593,14 +594,14 @@ class RoomsMemberListTestCase(RoomBase): # sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that the user can see the member list to start with channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user (Note: the user is actually allowed to see this event and # state so that they know they're banned!) @@ -612,14 +613,14 @@ class RoomsMemberListTestCase(RoomBase): # now, with the original user, sync again to get a new at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check the user can no longer see the updated member list channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self) -> None: room_creator = "@some_other_guy:red" @@ -628,17 +629,17 @@ class RoomsMemberListTestCase(RoomBase): self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. channel = self.make_request("GET", room_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_get_member_list_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/members` request.""" @@ -651,7 +652,7 @@ class RoomsMemberListTestCase(RoomBase): "/rooms/%s/members" % room_id, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertLessEqual( { @@ -671,7 +672,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] channel = make_request_with_cancellation_test( @@ -682,7 +683,7 @@ class RoomsMemberListTestCase(RoomBase): "/rooms/%s/members?at=%s" % (room_id, sync_token), ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertLessEqual( { @@ -706,10 +707,10 @@ class RoomsCreateTestCase(RoomBase): # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(37, channel.resource_usage.db_txn_count) + self.assertEqual(44, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -719,21 +720,21 @@ class RoomsCreateTestCase(RoomBase): b'{"initial_state":[{"type": "m.bridge", "content": {}}]}', ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(41, channel.resource_usage.db_txn_count) + self.assertEqual(50, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self) -> None: # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self) -> None: @@ -741,16 +742,16 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self) -> None: # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) def test_post_room_invitees_invalid_mxid(self) -> None: # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 @@ -758,7 +759,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) def test_post_room_invitees_ratelimit(self) -> None: @@ -769,20 +770,18 @@ class RoomsCreateTestCase(RoomBase): # Build the request's content. We use local MXIDs because invites over federation # are more difficult to mock. - content = json.dumps( - { - "invite": [ - "@alice1:red", - "@alice2:red", - "@alice3:red", - "@alice4:red", - ] - } - ).encode("utf8") + content = { + "invite": [ + "@alice1:red", + "@alice2:red", + "@alice3:red", + "@alice4:red", + ] + } # Test that the invites are correctly ratelimited. channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) self.assertEqual( "Cannot invite so many users at once", channel.json_body["error"], @@ -795,7 +794,7 @@ class RoomsCreateTestCase(RoomBase): # Test that the invites aren't ratelimited anymore. channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly bypassed @@ -819,7 +818,7 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) @@ -845,7 +844,7 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) @@ -865,7 +864,7 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) @@ -882,54 +881,68 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self) -> None: # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid key, wrong type content = '{"topic":["Topic name"]}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_topic(self) -> None: # nothing should be there channel = self.make_request("GET", self.path) - self.assertEqual(404, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self) -> None: # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -945,22 +958,34 @@ class RoomMemberStateTestCase(RoomBase): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -969,7 +994,9 @@ class RoomMemberStateTestCase(RoomBase): Membership.LEAVE, ) channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_members_self(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % ( @@ -980,10 +1007,10 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEqual(expected_response, channel.json_body) @@ -998,10 +1025,10 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self) -> None: @@ -1017,10 +1044,10 @@ class RoomMemberStateTestCase(RoomBase): "Join us!", ) channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) @@ -1137,7 +1164,9 @@ class RoomJoinTestCase(RoomBase): # Now make the callback deny all room joins, and check that a join actually fails. return_value = False - self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + self.helper.join( + self.room3, self.user2, expect_code=HTTPStatus.FORBIDDEN, tok=self.tok2 + ) def test_spam_checker_may_join_room(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called @@ -1205,7 +1234,7 @@ class RoomJoinTestCase(RoomBase): self.helper.join( self.room3, self.user2, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, expect_errcode=return_value, tok=self.tok2, ) @@ -1216,7 +1245,7 @@ class RoomJoinTestCase(RoomBase): self.helper.join( self.room3, self.user2, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, expect_errcode=return_value[0], tok=self.tok2, expect_additional_fields=return_value[1], @@ -1270,7 +1299,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that all the rooms have been sent a profile update into. for room_id in room_ids: @@ -1335,71 +1364,93 @@ class RoomMessagesTestCase(RoomBase): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_messages_sent(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' channel = self.make_request("PUT", path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) @parameterized.expand( [ # Allow param( - name="NOT_SPAM", value="NOT_SPAM", expected_code=200, expected_fields={} + name="NOT_SPAM", + value="NOT_SPAM", + expected_code=HTTPStatus.OK, + expected_fields={}, + ), + param( + name="False", + value=False, + expected_code=HTTPStatus.OK, + expected_fields={}, ), - param(name="False", value=False, expected_code=200, expected_fields={}), # Block param( name="scalene string", value="ANY OTHER STRING", - expected_code=403, + expected_code=HTTPStatus.FORBIDDEN, expected_fields={"errcode": "M_FORBIDDEN"}, ), param( name="True", value=True, - expected_code=403, + expected_code=HTTPStatus.FORBIDDEN, expected_fields={"errcode": "M_FORBIDDEN"}, ), param( name="Code", value=Codes.LIMIT_EXCEEDED, - expected_code=403, + expected_code=HTTPStatus.FORBIDDEN, expected_fields={"errcode": "M_LIMIT_EXCEEDED"}, ), param( name="Tuple", value=(Codes.SERVER_NOT_TRUSTED, {"additional_field": "12345"}), - expected_code=403, + expected_code=HTTPStatus.FORBIDDEN, expected_fields={ "errcode": "M_SERVER_NOT_TRUSTED", "additional_field": "12345", @@ -1584,7 +1635,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_normal_user_can_not_post_state_event(self) -> None: # Given I am a normal member of a room @@ -1598,7 +1649,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed because state events require PL>=50 - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " "user_level (0) < send_level (50)", @@ -1625,7 +1676,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1653,7 +1704,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1681,7 +1732,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " + "user_level (0) < send_level (1)", @@ -1712,7 +1763,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): # Then I am not allowed because the public_chat config does not # affect this room, because this room is a private_chat - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " + "user_level (0) < send_level (50)", @@ -1731,7 +1782,7 @@ class RoomInitialSyncTestCase(RoomBase): def test_initial_sync(self) -> None: channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertEqual(self.room_id, channel.json_body["room_id"]) self.assertEqual("join", channel.json_body["membership"]) @@ -1774,7 +1825,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1785,7 +1836,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1824,7 +1875,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) @@ -1852,7 +1903,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) @@ -1869,7 +1920,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) @@ -1997,14 +2048,14 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): def test_restricted_no_auth(self) -> None: channel = self.make_request("GET", self.url) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) def test_restricted_auth(self) -> None: self.register_user("user", "pass") tok = self.login("user", "pass") channel = self.make_request("GET", self.url, access_token=tok) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): @@ -2123,7 +2174,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined] "testserv", @@ -2140,7 +2191,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): # The `get_public_rooms` should be called again if the first call fails # with a 404, when using search filters. self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] - HttpResponseException(404, "Not Found", b""), + HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""), make_awaitable({}), ) @@ -2152,7 +2203,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined] [ @@ -2198,21 +2249,19 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): # Set a profile for the test user self.displayname = "test user" - data = {"displayname": self.displayname} - request_data = json.dumps(data) + request_data = {"displayname": self.displayname} channel = self.make_request( "PUT", "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) def test_per_room_profile_forbidden(self) -> None: - data = {"membership": "join", "displayname": "other test user"} - request_data = json.dumps(data) + request_data = {"membership": "join", "displayname": "other test user"} channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" @@ -2220,7 +2269,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -2228,7 +2277,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) @@ -2262,7 +2311,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2276,7 +2325,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2290,7 +2339,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2304,7 +2353,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2316,7 +2365,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2328,7 +2377,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2347,7 +2396,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2359,7 +2408,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): ), access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_content = channel.json_body @@ -2407,7 +2456,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2437,7 +2486,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2472,7 +2521,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2552,16 +2601,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_labels(self) -> None: """Test that we can filter by a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2589,16 +2636,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2638,16 +2683,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /search request. """ - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2820,7 +2863,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) return channel.json_body["chunk"] @@ -2925,7 +2968,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2991,7 +3034,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=invited_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -3092,8 +3135,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -3122,8 +3164,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -3149,7 +3190,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), - json.dumps(content), + content, access_token=self.room_owner_tok, ) self.assertEqual(channel.code, expected_code, channel.result) @@ -3283,7 +3324,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable(0)) + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock self.hs.get_identity_handler().lookup_3pid = Mock( return_value=make_awaitable(None), @@ -3344,7 +3385,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable(0)) + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock self.hs.get_identity_handler().lookup_3pid = Mock( return_value=make_awaitable(None), diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index e3efd1f1b0..b085c50356 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -606,11 +606,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(1) # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8") channel = self.make_request( "POST", f"/rooms/{self.room_id}/read_markers", - body, + {ReceiptTypes.READ: res["event_id"]}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 5eb0f243f7..9a48e9286f 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -21,7 +21,6 @@ from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError from synapse.api.room_versions import RoomVersion from synapse.events import EventBase -from synapse.events.snapshot import EventContext from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin from synapse.rest.client import account, login, profile, room @@ -113,14 +112,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Have this homeserver skip event auth checks. This is necessary due to # event auth checks ensuring that events were signed by the sender's homeserver. - async def _check_event_auth( - origin: str, - event: EventBase, - context: EventContext, - *args: Any, - **kwargs: Any, - ) -> EventContext: - return context + async def _check_event_auth(origin: Any, event: Any, context: Any) -> None: + pass hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 93f749744d..105d418698 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -136,7 +136,7 @@ class RestHelper: self.site, "POST", path, - json.dumps(content).encode("utf8"), + content, custom_headers=custom_headers, ) @@ -210,7 +210,7 @@ class RestHelper: self.site, "POST", path, - json.dumps(data).encode("utf8"), + data, ) assert ( @@ -309,7 +309,7 @@ class RestHelper: self.site, "PUT", path, - json.dumps(data).encode("utf8"), + data, ) assert ( @@ -392,7 +392,7 @@ class RestHelper: self.site, "PUT", path, - json.dumps(content or {}).encode("utf8"), + content or {}, custom_headers=custom_headers, ) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 79727c430f..d18fc13c21 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -126,7 +126,9 @@ class _TestImage: expected_scaled: The expected bytes from scaled thumbnailing, or None if test should just check for a valid image returned. expected_found: True if the file should exist on the server, or False if - a 404 is expected. + a 404/400 is expected. + unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or + False if the thumbnailing should succeed or a normal 404 is expected. """ data: bytes @@ -135,6 +137,7 @@ class _TestImage: expected_cropped: Optional[bytes] = None expected_scaled: Optional[bytes] = None expected_found: bool = True + unable_to_thumbnail: bool = False @parameterized_class( @@ -192,6 +195,7 @@ class _TestImage: b"image/gif", b".gif", expected_found=False, + unable_to_thumbnail=True, ), ), ], @@ -366,18 +370,29 @@ class MediaRepoTests(unittest.HomeserverTestCase): def test_thumbnail_crop(self) -> None: """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( - "crop", self.test_image.expected_cropped, self.test_image.expected_found + "crop", + self.test_image.expected_cropped, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) def test_thumbnail_scale(self) -> None: """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( - "scale", self.test_image.expected_scaled, self.test_image.expected_found + "scale", + self.test_image.expected_scaled, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) def test_invalid_type(self) -> None: """An invalid thumbnail type is never available.""" - self._test_thumbnail("invalid", None, False) + self._test_thumbnail( + "invalid", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} @@ -386,7 +401,12 @@ class MediaRepoTests(unittest.HomeserverTestCase): """ Override the config to generate only scaled thumbnails, but request a cropped one. """ - self._test_thumbnail("crop", None, False) + self._test_thumbnail( + "crop", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} @@ -395,14 +415,22 @@ class MediaRepoTests(unittest.HomeserverTestCase): """ Override the config to generate only cropped thumbnails, but request a scaled one. """ - self._test_thumbnail("scale", None, False) + self._test_thumbnail( + "scale", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) def test_thumbnail_repeated_thumbnail(self) -> None: """Test that fetching the same thumbnail works, and deleting the on disk thumbnail regenerates it. """ self._test_thumbnail( - "scale", self.test_image.expected_scaled, self.test_image.expected_found + "scale", + self.test_image.expected_scaled, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) if not self.test_image.expected_found: @@ -459,8 +487,24 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) def _test_thumbnail( - self, method: str, expected_body: Optional[bytes], expected_found: bool + self, + method: str, + expected_body: Optional[bytes], + expected_found: bool, + unable_to_thumbnail: bool = False, ) -> None: + """Test the given thumbnailing method works as expected. + + Args: + method: The thumbnailing method to use (crop, scale). + expected_body: The expected bytes from thumbnailing, or None if + test should just check for a valid image. + expected_found: True if the file should exist on the server, or False if + a 404/400 is expected. + unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or + False if the thumbnailing should succeed or a normal 404 is expected. + """ + params = "?width=32&height=32&method=" + method channel = make_request( self.reactor, @@ -496,6 +540,16 @@ class MediaRepoTests(unittest.HomeserverTestCase): else: # ensure that the result is at least some valid image Image.open(BytesIO(channel.result["body"])) + elif unable_to_thumbnail: + # A 400 with a JSON body. + self.assertEqual(channel.code, 400) + self.assertEqual( + channel.json_body, + { + "errcode": "M_UNKNOWN", + "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", + }, + ) else: # A 404 with a JSON body. self.assertEqual(channel.code, 404) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 38963ce4a7..46d829b062 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -143,7 +143,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) def test_simple(self): """Test that we cache events that we pull from the DB.""" @@ -160,7 +160,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): """ # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: # We keep hold of the event event though we never use it. @@ -170,7 +170,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: self.get_success(self.store.get_event(self.event_id)) @@ -345,7 +345,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) @contextmanager def blocking_get_event_calls( diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index e8c53f16d9..ba40124c8a 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock - from twisted.test.proto_helpers import MemoryReactor +from synapse.rest import admin +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import NotifCounts from synapse.util import Clock @@ -24,15 +24,14 @@ from tests.unittest import HomeserverTestCase USER_ID = "@user:example.com" -PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}] -HIGHLIGHT = [ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, -] - class EventPushActionsStoreTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main persist_events_store = hs.get_datastores().persist_events @@ -54,154 +53,118 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): ) def test_count_aggregation(self) -> None: - room_id = "!foo:example.com" - user_id = "@user1235:test" + # Create a user to receive notifications and send receipts. + user_id = self.register_user("user1235", "pass") + token = self.login("user1235", "pass") + + # And another users to send events. + other_id = self.register_user("other", "pass") + other_token = self.login("other", "pass") + + # Create a room and put both users in it. + room_id = self.helper.create_room_as(user_id, tok=token) + self.helper.join(room_id, other_id, tok=other_token) - last_read_stream_ordering = [0] + last_event_id: str - def _assert_counts(noitf_count: int, highlight_count: int) -> None: + def _assert_counts( + noitf_count: int, unread_count: int, highlight_count: int + ) -> None: counts = self.get_success( self.store.db_pool.runInteraction( - "", - self.store._get_unread_counts_by_pos_txn, + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, room_id, user_id, - last_read_stream_ordering[0], ) ) self.assertEqual( counts, NotifCounts( notify_count=noitf_count, - unread_count=0, # Unread counts are tested in the sync tests. + unread_count=unread_count, highlight_count=highlight_count, ), ) - def _inject_actions(stream: int, action: list) -> None: - event = Mock() - event.room_id = room_id - event.event_id = f"$test{stream}:example.com" - event.internal_metadata.stream_ordering = stream - event.internal_metadata.is_outlier.return_value = False - event.depth = stream - - self.store._events_stream_cache.entity_has_changed(room_id, stream) - - self.get_success( - self.store.db_pool.simple_insert( - table="events", - values={ - "stream_ordering": stream, - "topological_ordering": stream, - "type": "m.room.message", - "room_id": room_id, - "processed": True, - "outlier": False, - "event_id": event.event_id, - }, - ) + def _create_event(highlight: bool = False) -> str: + result = self.helper.send_event( + room_id, + type="m.room.message", + content={"msgtype": "m.text", "body": user_id if highlight else "msg"}, + tok=other_token, ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id - self.get_success( - self.store.add_push_actions_to_staging( - event.event_id, - {user_id: action}, - False, - ) - ) - self.get_success( - self.store.db_pool.runInteraction( - "", - self.persist_events_store._set_push_actions_for_event_and_users_txn, - [(event, None)], - [(event, None)], - ) - ) - - def _rotate(stream: int) -> None: - self.get_success( - self.store.db_pool.runInteraction( - "rotate-receipts", self.store._handle_new_receipts_for_notifs_txn - ) - ) - - self.get_success( - self.store.db_pool.runInteraction( - "rotate-notifs", self.store._rotate_notifs_before_txn, stream - ) - ) - - def _mark_read(stream: int, depth: int) -> None: - last_read_stream_ordering[0] = stream + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) + def _mark_read(event_id: str) -> None: self.get_success( self.store.insert_receipt( room_id, "m.read", user_id=user_id, - event_ids=[f"$test{stream}:example.com"], + event_ids=[event_id], data={}, ) ) - _assert_counts(0, 0) - _inject_actions(1, PlAIN_NOTIF) - _assert_counts(1, 0) - _rotate(1) - _assert_counts(1, 0) + _assert_counts(0, 0, 0) + _create_event() + _assert_counts(1, 1, 0) + _rotate() + _assert_counts(1, 1, 0) - _inject_actions(3, PlAIN_NOTIF) - _assert_counts(2, 0) - _rotate(3) - _assert_counts(2, 0) + event_id = _create_event() + _assert_counts(2, 2, 0) + _rotate() + _assert_counts(2, 2, 0) - _inject_actions(5, PlAIN_NOTIF) - _mark_read(3, 3) - _assert_counts(1, 0) + _create_event() + _mark_read(event_id) + _assert_counts(1, 1, 0) - _mark_read(5, 5) - _assert_counts(0, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) - _inject_actions(6, PlAIN_NOTIF) - _rotate(6) - _assert_counts(1, 0) - - self.get_success( - self.store.db_pool.simple_delete( - table="event_push_actions", keyvalues={"1": 1}, desc="" - ) - ) + _create_event() + _rotate() + _assert_counts(1, 1, 0) - _assert_counts(1, 0) + # Delete old event push actions, this should not affect the (summarised) count. + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + _assert_counts(1, 1, 0) - _mark_read(6, 6) - _assert_counts(0, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) - _inject_actions(8, HIGHLIGHT) - _assert_counts(1, 1) - _rotate(8) - _assert_counts(1, 1) + event_id = _create_event(True) + _assert_counts(1, 1, 1) + _rotate() + _assert_counts(1, 1, 1) # Check that adding another notification and rotating after highlight # works. - _inject_actions(10, PlAIN_NOTIF) - _rotate(10) - _assert_counts(2, 1) + _create_event() + _rotate() + _assert_counts(2, 2, 1) # Check that sending read receipts at different points results in the # right counts. - _mark_read(8, 8) - _assert_counts(1, 0) - _mark_read(10, 10) - _assert_counts(0, 0) - - _inject_actions(11, HIGHLIGHT) - _assert_counts(1, 1) - _mark_read(11, 11) - _assert_counts(0, 0) - _rotate(11) - _assert_counts(0, 0) + _mark_read(event_id) + _assert_counts(1, 1, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) + + _create_event(True) + _assert_counts(1, 1, 1) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) + _rotate() + _assert_counts(0, 0, 0) def test_find_first_stream_ordering_after_ts(self) -> None: def add_event(so: int, ts: int) -> None: diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 8dfaa0559b..9c1182ed16 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -115,6 +115,6 @@ class PurgeTests(HomeserverTestCase): ) # The events aren't found. - self.store._invalidate_get_event_cache(create_event.event_id) + self.store._invalidate_local_get_event_cache(create_event.event_id) self.get_failure(self.store.get_event(create_event.event_id), NotFoundError) self.get_failure(self.store.get_event(first["event_id"]), NotFoundError) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 1218786d79..240b02cb9f 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -23,7 +23,6 @@ from synapse.util import Clock from tests import unittest from tests.server import TestHomeServer -from tests.test_utils import event_injection class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @@ -110,60 +109,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) - def test_get_joined_users_from_context(self) -> None: - room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - bob_event = self.get_success( - event_injection.inject_member_event( - self.hs, room, self.u_bob, Membership.JOIN - ) - ) - - # first, create a regular event - event, context = self.get_success( - event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[bob_event.event_id], - type="m.test.1", - content={}, - ) - ) - - users = self.get_success( - self.store.get_joined_users_from_context(event, context) - ) - self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) - - # Regression test for #7376: create a state event whose key matches bob's - # user_id, but which is *not* a membership event, and persist that; then check - # that `get_joined_users_from_context` returns the correct users for the next event. - non_member_event = self.get_success( - event_injection.inject_event( - self.hs, - room_id=room, - sender=self.u_bob, - prev_event_ids=[bob_event.event_id], - type="m.test.2", - state_key=self.u_bob, - content={}, - ) - ) - event, context = self.get_success( - event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[non_member_event.event_id], - type="m.test.3", - content={}, - ) - ) - users = self.get_success( - self.store.get_joined_users_from_context(event, context) - ) - self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) - def test__null_byte_in_display_name_properly_handled(self) -> None: room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 371cd201af..e42d7b9ba0 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -19,7 +19,7 @@ from parameterized import parameterized from synapse import event_auth from synapse.api.constants import EventContentFields -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.storage.databases.main.events_worker import EventRedactBehaviour @@ -689,6 +689,45 @@ class EventAuthTestCase(unittest.TestCase): auth_events.values(), ) + def test_room_v10_rejects_string_power_levels(self) -> None: + pl_event_content = {"users_default": "42"} + pl_event = make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10), + "type": "m.room.power_levels", + "sender": "@test:test.com", + "state_key": "", + "content": pl_event_content, + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + }, + room_version=RoomVersions.V10, + ) + + pl_event2_content = {"events": {"m.room.name": "42", "m.room.power_levels": 42}} + pl_event2 = make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10), + "type": "m.room.power_levels", + "sender": "@test:test.com", + "state_key": "", + "content": pl_event2_content, + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + }, + room_version=RoomVersions.V10, + ) + + with self.assertRaises(SynapseError): + event_auth._check_power_levels( + pl_event.room_version, pl_event, {("fake_type", "fake_key"): pl_event2} + ) + + with self.assertRaises(SynapseError): + event_auth._check_power_levels( + pl_event.room_version, pl_event2, {("fake_type", "fake_key"): pl_event} + ) + # helpers for making events TEST_DOMAIN = "example.com" diff --git a/tests/test_federation.py b/tests/test_federation.py index 0cbef70bfa..779fad1f63 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -81,12 +81,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.handler = self.homeserver.get_federation_handler() federation_event_handler = self.homeserver.get_federation_event_handler() - async def _check_event_auth( - origin, - event, - context, - ): - return context + async def _check_event_auth(origin, event, context): + pass federation_event_handler._check_event_auth = _check_event_auth self.client = self.homeserver.get_federation_client() diff --git a/tests/test_server.py b/tests/test_server.py index fc4bce899c..2fe4411401 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -231,7 +231,7 @@ class OptionsResourceTests(unittest.TestCase): parse_listener_def({"type": "http", "port": 0}), self.resource, "1.0", - max_request_body_size=1234, + max_request_body_size=4096, reactor=self.reactor, ) diff --git a/tests/test_state.py b/tests/test_state.py index 6ca8d8f21d..bafd6d1750 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext -from synapse.state import StateHandler, StateResolutionHandler +from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry from synapse.util import Clock from synapse.util.macaroons import MacaroonGenerator @@ -99,6 +99,10 @@ class _DummyStore: state_group = self._next_group self._next_group += 1 + if current_state_ids is None: + current_state_ids = dict(self._group_to_state[prev_group]) + current_state_ids.update(delta_ids) + self._group_to_state[state_group] = dict(current_state_ids) return state_group @@ -760,3 +764,43 @@ class StateTestCase(unittest.TestCase): result = yield defer.ensureDeferred(self.state.compute_event_context(event)) return result + + def test_make_state_cache_entry(self): + "Test that calculating a prev_group and delta is correct" + + new_state = { + ("a", ""): "E", + ("b", ""): "E", + ("c", ""): "E", + ("d", ""): "E", + } + + # old_state_1 has fewer differences to new_state than old_state_2, but + # the delta involves deleting a key, which isn't allowed in the deltas, + # so we should pick old_state_2 as the prev_group. + + # `old_state_1` has two differences: `a` and `e` + old_state_1 = { + ("a", ""): "F", + ("b", ""): "E", + ("c", ""): "E", + ("d", ""): "E", + ("e", ""): "E", + } + + # `old_state_2` has three differences: `a`, `c` and `d` + old_state_2 = { + ("a", ""): "F", + ("b", ""): "E", + ("c", ""): "F", + ("d", ""): "F", + } + + entry = _make_state_cache_entry(new_state, {1: old_state_1, 2: old_state_2}) + + self.assertEqual(entry.prev_group, 2) + + # There are three changes from `old_state_2` to `new_state` + self.assertEqual( + entry.delta_ids, {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"} + ) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 37fada5c53..d3c13cf14c 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactorClock @@ -51,7 +50,7 @@ class TermsTestCase(unittest.HomeserverTestCase): def test_ui_auth(self): # Do a UI auth request - request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request_data = {"username": "kermit", "password": "monkey"} channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.result["code"], b"401", channel.result) @@ -82,16 +81,14 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(channel.json_body["params"], expected_params) # We have to complete the dummy auth stage before completing the terms stage - request_data = json.dumps( - { - "username": "kermit", - "password": "monkey", - "auth": { - "session": channel.json_body["session"], - "type": "m.login.dummy", - }, - } - ) + request_data = { + "username": "kermit", + "password": "monkey", + "auth": { + "session": channel.json_body["session"], + "type": "m.login.dummy", + }, + } self.registration_handler.check_username = Mock(return_value=True) @@ -102,16 +99,14 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"401", channel.result) # Finish the UI auth for terms - request_data = json.dumps( - { - "username": "kermit", - "password": "monkey", - "auth": { - "session": channel.json_body["session"], - "type": "m.login.terms", - }, - } - ) + request_data = { + "username": "kermit", + "password": "monkey", + "auth": { + "session": channel.json_body["session"], + "type": "m.login.terms", + }, + } channel = self.make_request(b"POST", self.url, request_data) # We're interested in getting a response that looks like a successful diff --git a/tests/test_visibility.py b/tests/test_visibility.py index f338af6c36..c385b2f8d4 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -272,7 +272,7 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): "state_key": "@user:test", "content": {"membership": "invite"}, } - self.add_hashes_and_signatures(invite_pdu) + self.add_hashes_and_signatures_from_other_server(invite_pdu) invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id self.get_success( diff --git a/tests/unittest.py b/tests/unittest.py index c645dd3563..66ce92f4a6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -16,7 +16,6 @@ import gc import hashlib import hmac -import json import logging import secrets import time @@ -285,7 +284,7 @@ class HomeserverTestCase(TestCase): config=self.hs.config.server.listeners[0], resource=self.resource, server_version_string="1", - max_request_body_size=1234, + max_request_body_size=4096, reactor=self.reactor, ) @@ -619,20 +618,16 @@ class HomeserverTestCase(TestCase): want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) want_mac_digest = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": username, - "displayname": displayname, - "password": password, - "admin": admin, - "mac": want_mac_digest, - "inhibit_login": True, - } - ) - channel = self.make_request( - "POST", "/_synapse/admin/v1/register", body.encode("utf8") - ) + body = { + "nonce": nonce, + "username": username, + "displayname": displayname, + "password": password, + "admin": admin, + "mac": want_mac_digest, + "inhibit_login": True, + } + channel = self.make_request("POST", "/_synapse/admin/v1/register", body) self.assertEqual(channel.code, 200, channel.json_body) user_id = channel.json_body["user_id"] @@ -676,9 +671,7 @@ class HomeserverTestCase(TestCase): custom_headers: Optional[Iterable[CustomHeaderType]] = None, ) -> str: """ - Log in a user, and get an access token. Requires the Login API be - registered. - + Log in a user, and get an access token. Requires the Login API be registered. """ body = {"type": "m.login.password", "user": username, "password": password} if device_id: @@ -687,7 +680,7 @@ class HomeserverTestCase(TestCase): channel = self.make_request( "POST", "/_matrix/client/r0/login", - json.dumps(body).encode("utf8"), + body, custom_headers=custom_headers, ) self.assertEqual(channel.code, 200, channel.result) @@ -780,7 +773,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): verify_key_id, FetchKeyResult( verify_key=verify_key, - valid_until_ts=clock.time_msec() + 1000, + valid_until_ts=clock.time_msec() + 10000, ), ) ], @@ -838,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): client_ip=client_ip, ) - def add_hashes_and_signatures( + def add_hashes_and_signatures_from_other_server( self, event_dict: JsonDict, room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], diff --git a/tests/utils.py b/tests/utils.py index 424cc4c2a0..d2c6d1e852 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -167,6 +167,7 @@ def default_config( "local": {"per_second": 10000, "burst_count": 10000}, "remote": {"per_second": 10000, "burst_count": 10000}, }, + "rc_joins_per_room": {"per_second": 10000, "burst_count": 10000}, "rc_invites": { "per_room": {"per_second": 10000, "burst_count": 10000}, "per_user": {"per_second": 10000, "burst_count": 10000}, |