diff options
Diffstat (limited to 'tests')
51 files changed, 2573 insertions, 1166 deletions
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 18649c2c05..c86f783c5b 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -314,3 +314,77 @@ class TestRatelimiter(unittest.HomeserverTestCase): # Check that we get rate limited after using that token. self.assertFalse(consume_at(11.1)) + + def test_record_action_which_doesnt_fill_bucket(self) -> None: + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + # Observe two actions, leaving room in the bucket for one more. + limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0) + + # We should be able to take a new action now. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertTrue(success) + + # ... but not two. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + def test_record_action_which_fills_bucket(self) -> None: + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + # Observe three actions, filling up the bucket. + limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0) + + # We should be unable to take a new action now. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + # If we wait 10 seconds to leak a token, we should be able to take one action... + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertTrue(success) + + # ... but not two. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertFalse(success) + + def test_record_action_which_overfills_bucket(self) -> None: + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + # Observe four actions, exceeding the bucket. + limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0) + + # We should be prevented from taking a new action now. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + # If we wait 10 seconds to leak a token, we should be unable to take an action + # because the bucket is still full. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertFalse(success) + + # But after another 10 seconds we leak a second token, giving us room for + # action. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=20.0) + ) + self.assertTrue(success) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 9f1115dd23..c6dd99316a 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError @@ -50,7 +51,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) complexity = channel.json_body["v1"] self.assertTrue(complexity > 0, complexity) @@ -62,7 +63,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) complexity = channel.json_body["v1"] self.assertEqual(complexity, 1.23) diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index 268a48d7ba..50e376f695 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -22,6 +22,7 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -38,42 +39,46 @@ class FederationClientTest(FederatingHomeserverTestCase): self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True) homeserver.get_federation_http_client().agent = self._mock_agent - def test_get_room_state(self): - creator = f"@creator:{self.OTHER_SERVER_NAME}" - test_room_id = "!room_id" + # Move clock up to somewhat realistic time so the PDU destination retry + # works (`now` needs to be larger than `0 + PDU_RETRY_TIME_MS`). + self.reactor.advance(1000000000) + + self.creator = f"@creator:{self.OTHER_SERVER_NAME}" + self.test_room_id = "!room_id" + def test_get_room_state(self): # mock up some events to use in the response. # In real life, these would have things in `prev_events` and `auth_events`, but that's # a bit annoying to mock up, and the code under test doesn't care, so we don't bother. - create_event_dict = self.add_hashes_and_signatures( + create_event_dict = self.add_hashes_and_signatures_from_other_server( { - "room_id": test_room_id, + "room_id": self.test_room_id, "type": "m.room.create", "state_key": "", - "sender": creator, - "content": {"creator": creator}, + "sender": self.creator, + "content": {"creator": self.creator}, "prev_events": [], "auth_events": [], "origin_server_ts": 500, } ) - member_event_dict = self.add_hashes_and_signatures( + member_event_dict = self.add_hashes_and_signatures_from_other_server( { - "room_id": test_room_id, + "room_id": self.test_room_id, "type": "m.room.member", - "sender": creator, - "state_key": creator, + "sender": self.creator, + "state_key": self.creator, "content": {"membership": "join"}, "prev_events": [], "auth_events": [], "origin_server_ts": 600, } ) - pl_event_dict = self.add_hashes_and_signatures( + pl_event_dict = self.add_hashes_and_signatures_from_other_server( { - "room_id": test_room_id, + "room_id": self.test_room_id, "type": "m.room.power_levels", - "sender": creator, + "sender": self.creator, "state_key": "", "content": {}, "prev_events": [], @@ -102,8 +107,8 @@ class FederationClientTest(FederatingHomeserverTestCase): # now fire off the request state_resp, auth_resp = self.get_success( self.hs.get_federation_client().get_room_state( - "yet_another_server", - test_room_id, + "yet.another.server", + self.test_room_id, "event_id", RoomVersions.V9, ) @@ -112,7 +117,7 @@ class FederationClientTest(FederatingHomeserverTestCase): # check the right call got made to the agent self._mock_agent.request.assert_called_once_with( b"GET", - b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id", + b"matrix://yet.another.server/_matrix/federation/v1/state/%21room_id?event_id=event_id", headers=mock.ANY, bodyProducer=None, ) @@ -130,6 +135,102 @@ class FederationClientTest(FederatingHomeserverTestCase): ["m.room.create", "m.room.member", "m.room.power_levels"], ) + def test_get_pdu_returns_nothing_when_event_does_not_exist(self): + """No event should be returned when the event does not exist""" + remote_pdu = self.get_success( + self.hs.get_federation_client().get_pdu( + ["yet.another.server"], + "event_should_not_exist", + RoomVersions.V9, + ) + ) + self.assertEqual(remote_pdu, None) + + def test_get_pdu(self): + """Test to make sure an event is returned by `get_pdu()`""" + self._get_pdu_once() + + def test_get_pdu_event_from_cache_is_pristine(self): + """Test that modifications made to events returned by `get_pdu()` + do not propagate back to to the internal cache (events returned should + be a copy). + """ + + # Get the PDU in the cache + remote_pdu = self._get_pdu_once() + + # Modify the the event reference. + # This change should not make it back to the `_get_pdu_cache`. + remote_pdu.internal_metadata.outlier = True + + # Get the event again. This time it should read it from cache. + remote_pdu2 = self.get_success( + self.hs.get_federation_client().get_pdu( + ["yet.another.server"], + remote_pdu.event_id, + RoomVersions.V9, + ) + ) + + # Sanity check that we are working against the same event + self.assertEqual(remote_pdu.event_id, remote_pdu2.event_id) + + # Make sure the event does not include modification from earlier + self.assertIsNotNone(remote_pdu2) + self.assertEqual(remote_pdu2.internal_metadata.outlier, False) + + def _get_pdu_once(self) -> EventBase: + """Retrieve an event via `get_pdu()` and assert that an event was returned. + Also used to prime the cache for subsequent test logic. + """ + message_event_dict = self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.test_room_id, + "type": "m.room.message", + "sender": self.creator, + "state_key": "", + "content": {}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 700, + "depth": 10, + } + ) + + # mock up the response, and have the agent return it + self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( + _mock_response( + { + "origin": "yet.another.server", + "origin_server_ts": 900, + "pdus": [ + message_event_dict, + ], + } + ) + ) + + remote_pdu = self.get_success( + self.hs.get_federation_client().get_pdu( + ["yet.another.server"], + "event_id", + RoomVersions.V9, + ) + ) + + # check the right call got made to the agent + self._mock_agent.request.assert_called_once_with( + b"GET", + b"matrix://yet.another.server/_matrix/federation/v1/event/event_id", + headers=mock.ANY, + bodyProducer=None, + ) + + self.assertIsNotNone(remote_pdu) + self.assertEqual(remote_pdu.internal_metadata.outlier, False) + + return remote_pdu + def _mock_response(resp: JsonDict): body = json.dumps(resp).encode("utf-8") diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 413b3c9426..3a6ef221ae 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from parameterized import parameterized @@ -20,7 +21,6 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin @@ -59,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/get_missing_events/%s" % (room_1,), query_content, ) - self.assertEqual(400, channel.code, channel.result) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") @@ -120,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,) ) - self.assertEqual(403, channel.code, channel.result) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -148,13 +148,13 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): tok2 = self.login("fozzie", "bear") self.helper.join(self._room_id, second_member_user_id, tok=tok2) - def _make_join(self, user_id) -> JsonDict: + def _make_join(self, user_id: str) -> JsonDict: channel = self.make_signed_federation_request( "GET", f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}" f"?ver={DEFAULT_ROOM_VERSION}", ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) return channel.json_body def test_send_join(self): @@ -163,18 +163,16 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): join_result = self._make_join(joining_user) join_event_dict = join_result["event"] - add_hashes_and_signatures( - KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + self.add_hashes_and_signatures_from_other_server( join_event_dict, - signature_name=self.OTHER_SERVER_NAME, - signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], ) channel = self.make_signed_federation_request( "PUT", f"/_matrix/federation/v2/send_join/{self._room_id}/x", content=join_event_dict, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # we should get complete room state back returned_state = [ @@ -220,18 +218,16 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): join_result = self._make_join(joining_user) join_event_dict = join_result["event"] - add_hashes_and_signatures( - KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + self.add_hashes_and_signatures_from_other_server( join_event_dict, - signature_name=self.OTHER_SERVER_NAME, - signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], ) channel = self.make_signed_federation_request( "PUT", f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", content=join_event_dict, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # expect a reduced room state returned_state = [ @@ -264,6 +260,67 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}}) + def test_make_join_respects_room_join_rate_limit(self) -> None: + # In the test setup, two users join the room. Since the rate limiter burst + # count is 3, a new make_join request to the room should be accepted. + + joining_user = "@ronniecorbett:" + self.OTHER_SERVER_NAME + self._make_join(joining_user) + + # Now have a new local user join the room. This saturates the rate limiter + # bucket, so the next make_join should be denied. + new_local_user = self.register_user("animal", "animal") + token = self.login("animal", "animal") + self.helper.join(self._room_id, new_local_user, tok=token) + + joining_user = "@ronniebarker:" + self.OTHER_SERVER_NAME + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/make_join/{self._room_id}/{joining_user}" + f"?ver={DEFAULT_ROOM_VERSION}", + ) + self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body) + + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}}) + def test_send_join_contributes_to_room_join_rate_limit_and_is_limited(self) -> None: + # Make two make_join requests up front. (These are rate limited, but do not + # contribute to the rate limit.) + join_event_dicts = [] + for i in range(2): + joining_user = f"@misspiggy{i}:{self.OTHER_SERVER_NAME}" + join_result = self._make_join(joining_user) + join_event_dict = join_result["event"] + self.add_hashes_and_signatures_from_other_server( + join_event_dict, + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + ) + join_event_dicts.append(join_event_dict) + + # In the test setup, two users join the room. Since the rate limiter burst + # count is 3, the first send_join should be accepted... + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self._room_id}/join0", + content=join_event_dicts[0], + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # ... but the second should be denied. + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self._room_id}/join1", + content=join_event_dicts[1], + ) + self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body) + + # NB: we could write a test which checks that the send_join event is seen + # by other workers over replication, and that they update their rate limit + # buckets accordingly. I'm going to assume that the join event gets sent over + # replication, at which point the tests.handlers.room_member test + # test_local_users_joining_on_another_worker_contribute_to_rate_limit + # is probably sufficient to reassure that the bucket is updated. + def _create_acl_event(content): return make_event_from_dict( diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index d21c11b716..0d048207b7 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict +from http import HTTPStatus from typing import Dict, List from synapse.api.constants import EventTypes, JoinRules, Membership @@ -255,7 +256,7 @@ class FederationKnockingTestCase( RoomVersions.V7.identifier, ), ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Note: We don't expect the knock membership event to be sent over federation as # part of the stripped room state, as the knocking homeserver already has that @@ -293,7 +294,7 @@ class FederationKnockingTestCase( % (room_id, signed_knock_event.event_id), signed_knock_event_json, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Check that we got the stripped room state in return room_state_events = channel.json_body["knock_state_events"] diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index d96d5aa138..b17af2725b 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -50,7 +50,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_scheduler = Mock() hs = Mock() hs.get_datastores.return_value = Mock(main=self.mock_store) - self.mock_store.get_received_ts.return_value = make_awaitable(0) + self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( None @@ -76,9 +76,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) - self.mock_store.get_new_events_for_appservice.side_effect = [ - make_awaitable((0, [])), - make_awaitable((1, [event])), + self.mock_store.get_all_new_events_stream.side_effect = [ + make_awaitable((0, [], {})), + make_awaitable((1, [event], {event.event_id: 0})), ] self.handler.notify_interested_services(RoomStreamToken(None, 1)) @@ -95,8 +95,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_new_events_for_appservice.side_effect = [ - make_awaitable((0, [event])), + self.mock_store.get_all_new_events_stream.side_effect = [ + make_awaitable((0, [event], {event.event_id: 0})), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -112,8 +112,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_new_events_for_appservice.side_effect = [ - make_awaitable((0, [event])), + self.mock_store.get_all_new_events_stream.side_effect = [ + make_awaitable((0, [event], {event.event_id: 0})), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 9afba7b0e8..8a0bb91f40 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, cast +from typing import cast from unittest import TestCase from twisted.test.proto_helpers import MemoryReactor @@ -50,8 +50,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main - self.state_storage_controller = hs.get_storage_controllers().state - self._event_auth_handler = hs.get_event_auth_handler() return hs def test_exchange_revoked_invite(self) -> None: @@ -225,9 +223,10 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # we need a user on the remote server to be a member, so that we can send # extremity-causing events. + remote_server_user_id = f"@user:{self.OTHER_SERVER_NAME}" self.get_success( event_injection.inject_member_event( - self.hs, room_id, f"@user:{self.OTHER_SERVER_NAME}", "join" + self.hs, room_id, remote_server_user_id, "join" ) ) @@ -247,9 +246,15 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # create more than is 5 which corresponds to the number of backward # extremities we slice off in `_maybe_backfill_inner` federation_event_handler = self.hs.get_federation_event_handler() + auth_events = [ + ev + for ev in current_state + if (ev.type, ev.state_key) + in {("m.room.create", ""), ("m.room.member", remote_server_user_id)} + ] for _ in range(0, 8): event = make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "origin_server_ts": 1, "type": "m.room.message", @@ -258,15 +263,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): "body": "message connected to fake event", }, "room_id": room_id, - "sender": f"@user:{self.OTHER_SERVER_NAME}", + "sender": remote_server_user_id, "prev_events": [ ev1.event_id, # We're creating an backward extremity each time thanks # to this fake event generate_fake_event_id(), ], - # lazy: *everything* is an auth event - "auth_events": [ev.event_id for ev in current_state], + "auth_events": [ev.event_id for ev in auth_events], "depth": ev1.depth + 1, }, room_version, @@ -308,142 +312,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) self.get_success(d) - def test_backfill_floating_outlier_membership_auth(self) -> None: - """ - As the local homeserver, check that we can properly process a federated - event from the OTHER_SERVER with auth_events that include a floating - membership event from the OTHER_SERVER. - - Regression test, see #10439. - """ - OTHER_SERVER = "otherserver" - OTHER_USER = "@otheruser:" + OTHER_SERVER - - # create the room - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") - room_id = self.helper.create_room_as( - room_creator=user_id, - is_public=True, - tok=tok, - extra_content={ - "preset": "public_chat", - }, - ) - room_version = self.get_success(self.store.get_room_version(room_id)) - - prev_event_ids = self.get_success(self.store.get_prev_events_for_room(room_id)) - ( - most_recent_prev_event_id, - most_recent_prev_event_depth, - ) = self.get_success(self.store.get_max_depth_of(prev_event_ids)) - # mapping from (type, state_key) -> state_event_id - assert most_recent_prev_event_id is not None - prev_state_map = self.get_success( - self.state_storage_controller.get_state_ids_for_event( - most_recent_prev_event_id - ) - ) - # List of state event ID's - prev_state_ids = list(prev_state_map.values()) - auth_event_ids = prev_state_ids - auth_events = list( - self.get_success(self.store.get_events(auth_event_ids)).values() - ) - - # build a floating outlier member state event - fake_prev_event_id = "$" + random_string(43) - member_event_dict = { - "type": EventTypes.Member, - "content": { - "membership": "join", - }, - "state_key": OTHER_USER, - "room_id": room_id, - "sender": OTHER_USER, - "depth": most_recent_prev_event_depth, - "prev_events": [fake_prev_event_id], - "origin_server_ts": self.clock.time_msec(), - "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}}, - } - builder = self.hs.get_event_builder_factory().for_room_version( - room_version, member_event_dict - ) - member_event = self.get_success( - builder.build( - prev_event_ids=member_event_dict["prev_events"], - auth_event_ids=self._event_auth_handler.compute_auth_events( - builder, - prev_state_map, - for_verification=False, - ), - depth=member_event_dict["depth"], - ) - ) - # Override the signature added from "test" homeserver that we created the event with - member_event.signatures = member_event_dict["signatures"] - - # Add the new member_event to the StateMap - updated_state_map = dict(prev_state_map) - updated_state_map[ - (member_event.type, member_event.state_key) - ] = member_event.event_id - auth_events.append(member_event) - - # build and send an event authed based on the member event - message_event_dict = { - "type": EventTypes.Message, - "content": {}, - "room_id": room_id, - "sender": OTHER_USER, - "depth": most_recent_prev_event_depth, - "prev_events": prev_event_ids.copy(), - "origin_server_ts": self.clock.time_msec(), - "signatures": {OTHER_SERVER: {"ed25519:key_version": "SomeSignatureHere"}}, - } - builder = self.hs.get_event_builder_factory().for_room_version( - room_version, message_event_dict - ) - message_event = self.get_success( - builder.build( - prev_event_ids=message_event_dict["prev_events"], - auth_event_ids=self._event_auth_handler.compute_auth_events( - builder, - updated_state_map, - for_verification=False, - ), - depth=message_event_dict["depth"], - ) - ) - # Override the signature added from "test" homeserver that we created the event with - message_event.signatures = message_event_dict["signatures"] - - # Stub the /event_auth response from the OTHER_SERVER - async def get_event_auth( - destination: str, room_id: str, event_id: str - ) -> List[EventBase]: - return [ - event_from_pdu_json(ae.get_pdu_json(), room_version=room_version) - for ae in auth_events - ] - - self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment] - - with LoggingContext("receive_pdu"): - # Fake the OTHER_SERVER federating the message event over to our local homeserver - d = run_in_background( - self.hs.get_federation_event_handler().on_receive_pdu, - OTHER_SERVER, - message_event, - ) - self.get_success(d) - - # Now try and get the events on our local homeserver - stored_event = self.get_success( - self.store.get_event(message_event.event_id, allow_none=True) - ) - self.assertTrue(stored_event is not None) - @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 1a36c25c41..51c8dd6498 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -98,14 +98,13 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): auth_event_ids = [ initial_state_map[("m.room.create", "")], initial_state_map[("m.room.power_levels", "")], - initial_state_map[("m.room.join_rules", "")], member_event.event_id, ] # mock up a load of state events which we are missing state_events = [ make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "type": "test_state_type", "state_key": f"state_{i}", @@ -132,7 +131,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # Depending on the test, we either persist this upfront (as an outlier), # or let the server request it. prev_event = make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "type": "test_regular_type", "room_id": room_id, @@ -166,7 +165,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): # mock up a regular event to pass into _process_pulled_event pulled_event = make_event_from_dict( - self.add_hashes_and_signatures( + self.add_hashes_and_signatures_from_other_server( { "type": "test_regular_type", "room_id": room_id, diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 82b3bb3b73..4c62449c89 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -14,6 +14,7 @@ """Tests for the password_auth_provider interface""" +from http import HTTPStatus from typing import Any, Type, Union from unittest.mock import Mock @@ -188,14 +189,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # check_password must return an awaitable mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._send_password_login("u", "p") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@u:test", channel.json_body["user_id"]) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") mock_password_provider.reset_mock() # login with mxid should work too channel = self._send_password_login("@u:bz", "p") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@u:bz", channel.json_body["user_id"]) mock_password_provider.check_password.assert_called_once_with("@u:bz", "p") mock_password_provider.reset_mock() @@ -204,7 +205,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # in these cases, but at least we can guard against the API changing # unexpectedly channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"]) mock_password_provider.check_password.assert_called_once_with( "@ USER🙂NAME :test", " pASS😢word " @@ -258,10 +259,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # check_password must return an awaitable mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("u", "p") - self.assertEqual(channel.code, 403, channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@localuser:test", channel.json_body["user_id"]) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) @@ -382,7 +383,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("u", "p") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_password.assert_not_called() @override_config(legacy_providers_config(LegacyCustomAuthProvider)) @@ -406,14 +407,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login with missing param should be rejected channel = self._send_login("test.login_type", "u") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", None) ) channel = self._send_login("test.login_type", "u", test_field="y") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) mock_password_provider.check_auth.assert_called_once_with( "u", "test.login_type", {"test_field": "y"} @@ -427,7 +428,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"]) mock_password_provider.check_auth.assert_called_once_with( " USER🙂NAME ", "test.login_type", {"test_field": " abc "} @@ -510,7 +511,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ("@user:bz", callback) ) channel = self._send_login("test.login_type", "u", test_field="y") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) mock_password_provider.check_auth.assert_called_once_with( "u", "test.login_type", {"test_field": "y"} @@ -549,7 +550,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() @override_config( @@ -584,7 +585,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() @override_config( @@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # login shouldn't work and should be rejected with a 400 ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_password.assert_not_called() @@ -646,13 +647,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) tok1 = channel.json_body["access_token"] channel = self._send_login( "test.login_type", "localuser", test_field="", device_id="dev2" ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # make the initial request which returns a 401 channel = self._delete_device(tok1, "dev2") @@ -721,7 +722,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # password login shouldn't work and should be rejected with a 400 # ("unknown login type") channel = self._send_password_login("localuser", "localpass") - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) def test_on_logged_out(self): """Tests that the on_logged_out callback is called when the user logs out.""" @@ -884,7 +885,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): }, access_token=tok, ) - self.assertEqual(channel.code, 403, channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) self.assertEqual( channel.json_body["errcode"], Codes.THREEPID_DENIED, @@ -906,7 +907,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): }, access_token=tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertIn("sid", channel.json_body) m.assert_called_once_with("email", "bar@test.com", registration) @@ -949,12 +950,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "register", {"auth": {"session": session, "type": LoginType.DUMMY}}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) return channel.json_body def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) return channel.json_body["flows"] def _send_password_login(self, user: str, password: str) -> FakeChannel: diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py new file mode 100644 index 0000000000..254e7e4b80 --- /dev/null +++ b/tests/handlers/test_room_member.py @@ -0,0 +1,290 @@ +from http import HTTPStatus +from unittest.mock import Mock, patch + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +import synapse.rest.client.login +import synapse.rest.client.room +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import LimitExceededError +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events import FrozenEventV3 +from synapse.federation.federation_client import SendJoinResult +from synapse.server import HomeServer +from synapse.types import UserID, create_requester +from synapse.util import Clock + +from tests.replication._base import RedisMultiWorkerStreamTestCase +from tests.server import make_request +from tests.test_utils import make_awaitable +from tests.unittest import FederatingHomeserverTestCase, override_config + + +class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.client.login.register_servlets, + synapse.rest.client.room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.handler = hs.get_room_member_handler() + + # Create three users. + self.alice = self.register_user("alice", "pass") + self.alice_token = self.login("alice", "pass") + self.bob = self.register_user("bob", "pass") + self.bob_token = self.login("bob", "pass") + self.chris = self.register_user("chris", "pass") + self.chris_token = self.login("chris", "pass") + + # Create a room on this homeserver. Note that this counts as a join: it + # contributes to the rate limter's count of actions + self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token) + + self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}" + + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) + def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None: + # The rate limiter has accumulated one token from Alice's join after the create + # event. + # Try joining the room as Bob. + self.get_success( + self.handler.update_membership( + requester=create_requester(self.bob), + target=UserID.from_string(self.bob), + room_id=self.room_id, + action=Membership.JOIN, + ) + ) + + # The rate limiter bucket is full. A second join should be denied. + self.get_failure( + self.handler.update_membership( + requester=create_requester(self.chris), + target=UserID.from_string(self.chris), + room_id=self.room_id, + action=Membership.JOIN, + ), + LimitExceededError, + ) + + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) + def test_local_user_profile_edits_dont_contribute_to_limit(self) -> None: + # The rate limiter has accumulated one token from Alice's join after the create + # event. Alice should still be able to change her displayname. + self.get_success( + self.handler.update_membership( + requester=create_requester(self.alice), + target=UserID.from_string(self.alice), + room_id=self.room_id, + action=Membership.JOIN, + content={"displayname": "Alice Cooper"}, + ) + ) + + # Still room in the limiter bucket. Chris's join should be accepted. + self.get_success( + self.handler.update_membership( + requester=create_requester(self.chris), + target=UserID.from_string(self.chris), + room_id=self.room_id, + action=Membership.JOIN, + ) + ) + + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}}) + def test_remote_joins_contribute_to_rate_limit(self) -> None: + # Join once, to fill the rate limiter bucket. + # + # To do this we have to mock the responses from the remote homeserver. + # We also patch out a bunch of event checks on our end. All we're really + # trying to check here is that remote joins will bump the rate limter when + # they are persisted. + create_event_source = { + "auth_events": [], + "content": { + "creator": f"@creator:{self.OTHER_SERVER_NAME}", + "room_version": self.hs.config.server.default_room_version.identifier, + }, + "depth": 0, + "origin_server_ts": 0, + "prev_events": [], + "room_id": self.intially_unjoined_room_id, + "sender": f"@creator:{self.OTHER_SERVER_NAME}", + "state_key": "", + "type": EventTypes.Create, + } + self.add_hashes_and_signatures_from_other_server( + create_event_source, + self.hs.config.server.default_room_version, + ) + create_event = FrozenEventV3( + create_event_source, + self.hs.config.server.default_room_version, + {}, + None, + ) + + join_event_source = { + "auth_events": [create_event.event_id], + "content": {"membership": "join"}, + "depth": 1, + "origin_server_ts": 100, + "prev_events": [create_event.event_id], + "sender": self.bob, + "state_key": self.bob, + "room_id": self.intially_unjoined_room_id, + "type": EventTypes.Member, + } + add_hashes_and_signatures( + self.hs.config.server.default_room_version, + join_event_source, + self.hs.hostname, + self.hs.signing_key, + ) + join_event = FrozenEventV3( + join_event_source, + self.hs.config.server.default_room_version, + {}, + None, + ) + + mock_make_membership_event = Mock( + return_value=make_awaitable( + ( + self.OTHER_SERVER_NAME, + join_event, + self.hs.config.server.default_room_version, + ) + ) + ) + mock_send_join = Mock( + return_value=make_awaitable( + SendJoinResult( + join_event, + self.OTHER_SERVER_NAME, + state=[create_event], + auth_chain=[create_event], + partial_state=False, + servers_in_room=[], + ) + ) + ) + + with patch.object( + self.handler.federation_handler.federation_client, + "make_membership_event", + mock_make_membership_event, + ), patch.object( + self.handler.federation_handler.federation_client, + "send_join", + mock_send_join, + ), patch( + "synapse.event_auth._is_membership_change_allowed", + return_value=None, + ), patch( + "synapse.handlers.federation_event.check_state_dependent_auth_rules", + return_value=None, + ): + self.get_success( + self.handler.update_membership( + requester=create_requester(self.bob), + target=UserID.from_string(self.bob), + room_id=self.intially_unjoined_room_id, + action=Membership.JOIN, + remote_room_hosts=[self.OTHER_SERVER_NAME], + ) + ) + + # Try to join as Chris. Should get denied. + self.get_failure( + self.handler.update_membership( + requester=create_requester(self.chris), + target=UserID.from_string(self.chris), + room_id=self.intially_unjoined_room_id, + action=Membership.JOIN, + remote_room_hosts=[self.OTHER_SERVER_NAME], + ), + LimitExceededError, + ) + + # TODO: test that remote joins to a room are rate limited. + # Could do this by setting the burst count to 1, then: + # - remote-joining a room + # - immediately leaving + # - trying to remote-join again. + + +class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.client.login.register_servlets, + synapse.rest.client.room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.handler = hs.get_room_member_handler() + + # Create three users. + self.alice = self.register_user("alice", "pass") + self.alice_token = self.login("alice", "pass") + self.bob = self.register_user("bob", "pass") + self.bob_token = self.login("bob", "pass") + self.chris = self.register_user("chris", "pass") + self.chris_token = self.login("chris", "pass") + + # Create a room on this homeserver. + # Note that this counts as a + self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token) + self.intially_unjoined_room_id = "!example:otherhs" + + @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) + def test_local_users_joining_on_another_worker_contribute_to_rate_limit( + self, + ) -> None: + # The rate limiter has accumulated one token from Alice's join after the create + # event. + self.replicate() + + # Spawn another worker and have bob join via it. + worker_app = self.make_worker_hs( + "synapse.app.generic_worker", extra_config={"worker_name": "other worker"} + ) + worker_site = self._hs_to_site[worker_app] + channel = make_request( + self.reactor, + worker_site, + "POST", + f"/_matrix/client/v3/rooms/{self.room_id}/join", + access_token=self.bob_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # wait for join to arrive over replication + self.replicate() + + # Try to join as Chris on the worker. Should get denied because Alice + # and Bob have both joined the room. + self.get_failure( + worker_app.get_room_member_handler().update_membership( + requester=create_requester(self.chris), + target=UserID.from_string(self.chris), + room_id=self.room_id, + action=Membership.JOIN, + ), + LimitExceededError, + ) + + # Try to join as Chris on the original worker. Should get denied because Alice + # and Bob have both joined the room. + self.get_failure( + self.handler.update_membership( + requester=create_requester(self.chris), + target=UserID.from_string(self.chris), + room_id=self.room_id, + action=Membership.JOIN, + ), + LimitExceededError, + ) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index ecc7cc6461..e3f38fbcc5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -159,7 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Blow away caches (supported room versions can only change due to a restart). self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) self.store._event_ref.clear() # The rooms should be excluded from the sync response. diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py index e430941d27..3b14c76d7e 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import cast + from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactorClock @@ -40,6 +42,15 @@ from tests.unittest import TestCase class LogContextScopeManagerTestCase(TestCase): + """ + Test logging contexts and active opentracing spans. + + There's casts throughout this from generic opentracing objects (e.g. + opentracing.Span) to the ones specific to Jaeger since they have additional + properties that these tests depend on. This is safe since the only supported + opentracing backend is Jaeger. + """ + if LogContextScopeManager is None: skip = "Requires opentracing" # type: ignore[unreachable] if jaeger_client is None: @@ -50,7 +61,7 @@ class LogContextScopeManagerTestCase(TestCase): # global variables that power opentracing. We create our own tracer instance # and test with it. - scope_manager = LogContextScopeManager({}) + scope_manager = LogContextScopeManager() config = jaeger_client.config.Config( config={}, service_name="test", scope_manager=scope_manager ) @@ -69,7 +80,7 @@ class LogContextScopeManagerTestCase(TestCase): # start_active_span should start and activate a span. scope = start_active_span("span", tracer=self._tracer) - span = scope.span + span = cast(jaeger_client.Span, scope.span) self.assertEqual(self._tracer.active_span, span) self.assertIsNotNone(span.start_time) @@ -91,6 +102,7 @@ class LogContextScopeManagerTestCase(TestCase): with LoggingContext("root context"): with start_active_span("root span", tracer=self._tracer) as root_scope: self.assertEqual(self._tracer.active_span, root_scope.span) + root_context = cast(jaeger_client.SpanContext, root_scope.span.context) scope1 = start_active_span( "child1", @@ -99,9 +111,8 @@ class LogContextScopeManagerTestCase(TestCase): self.assertEqual( self._tracer.active_span, scope1.span, "child1 was not activated" ) - self.assertEqual( - scope1.span.context.parent_id, root_scope.span.context.span_id - ) + context1 = cast(jaeger_client.SpanContext, scope1.span.context) + self.assertEqual(context1.parent_id, root_context.span_id) scope2 = start_active_span_follows_from( "child2", @@ -109,17 +120,18 @@ class LogContextScopeManagerTestCase(TestCase): tracer=self._tracer, ) self.assertEqual(self._tracer.active_span, scope2.span) - self.assertEqual( - scope2.span.context.parent_id, scope1.span.context.span_id - ) + context2 = cast(jaeger_client.SpanContext, scope2.span.context) + self.assertEqual(context2.parent_id, context1.span_id) with scope1, scope2: pass # the root scope should be restored self.assertEqual(self._tracer.active_span, root_scope.span) - self.assertIsNotNone(scope2.span.end_time) - self.assertIsNotNone(scope1.span.end_time) + span2 = cast(jaeger_client.Span, scope2.span) + span1 = cast(jaeger_client.Span, scope1.span) + self.assertIsNotNone(span2.end_time) + self.assertIsNotNone(span1.end_time) self.assertIsNone(self._tracer.active_span) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 9b623d0033..718f489577 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -16,13 +16,23 @@ from typing import Dict, Optional, Set, Tuple, Union import frozendict +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions +from synapse.appservice import ApplicationService from synapse.events import FrozenEvent from synapse.push import push_rule_evaluator from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent +from synapse.rest.client import login, register, room +from synapse.server import HomeServer +from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest +from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): @@ -354,3 +364,78 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "event_type": "*.reaction", } self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) + + +class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): + """Tests for the bulk push rule evaluator""" + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + # Define an application service so that we can register appservice users + self._service_token = "some_token" + self._service = ApplicationService( + self._service_token, + "as1", + "@as.sender:test", + namespaces={ + "users": [ + {"regex": "@_as_.*:test", "exclusive": True}, + {"regex": "@as.sender:test", "exclusive": True}, + ] + }, + msc3202_transaction_extensions=True, + ) + self.hs.get_datastores().main.services_cache = [self._service] + self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex( + [self._service] + ) + + self._as_user, _ = self.register_appservice_user( + "_as_user", self._service_token + ) + + self.evaluator = self.hs.get_bulk_push_rule_evaluator() + + def test_ignore_appservice_users(self) -> None: + "Test that we don't generate push for appservice users" + + user_id = self.register_user("user", "pass") + token = self.login("user", "pass") + + room_id = self.helper.create_room_as(user_id, tok=token) + self.get_success( + inject_member_event(self.hs, room_id, self._as_user, Membership.JOIN) + ) + + event, context = self.get_success( + create_event( + self.hs, + type=EventTypes.Message, + room_id=room_id, + sender=user_id, + content={"body": "test", "msgtype": "m.text"}, + ) + ) + + # Assert the returned push rules do not contain the app service user + rules = self.get_success(self.evaluator._get_rules_for_event(event)) + self.assertTrue(self._as_user not in rules) + + # Assert that no push actions have been added to the staging table (the + # sender should not be pushed for the event) + users_with_push_actions = self.get_success( + self.hs.get_datastores().main.db_pool.simple_select_onecol( + table="event_push_actions_staging", + keyvalues={"event_id": event.event_id}, + retcol="user_id", + desc="test_ignore_appservice_users", + ) + ) + + self.assertEqual(len(users_with_push_actions), 0) diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py deleted file mode 100644 index 1524087c43..0000000000 --- a/tests/replication/slave/storage/test_account_data.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore - -from ._base import BaseSlavedStoreTestCase - -USER_ID = "@feeling:blue" -TYPE = "my.type" - - -class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): - - STORE_TYPE = SlavedAccountDataStore - - def test_user_account_data(self): - self.get_success( - self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) - ) - self.replicate() - self.check( - "get_global_account_data_by_type_for_user", [USER_ID, TYPE], {"a": 1} - ) - - self.get_success( - self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) - ) - self.replicate() - self.check( - "get_global_account_data_by_type_for_user", [USER_ID, TYPE], {"a": 2} - ) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index ca6af9417b..2526136ff8 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -21,7 +21,7 @@ from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RoomTypes from synapse.api.errors import Codes from synapse.handlers.pagination import PaginationHandler from synapse.rest.client import directory, events, login, room @@ -1130,6 +1130,8 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("guest_access", r) self.assertIn("history_visibility", r) self.assertIn("state_events", r) + self.assertIn("room_type", r) + self.assertIsNone(r["room_type"]) # Check that the correct number of total rooms was returned self.assertEqual(channel.json_body["total_rooms"], total_rooms) @@ -1229,7 +1231,11 @@ class RoomTestCase(unittest.HomeserverTestCase): def test_correct_room_attributes(self) -> None: """Test the correct attributes for a room are returned""" # Create a test room - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id = self.helper.create_room_as( + self.admin_user, + tok=self.admin_user_tok, + extra_content={"creation_content": {"type": RoomTypes.SPACE}}, + ) test_alias = "#test:test" test_room_name = "something" @@ -1306,6 +1312,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertEqual(room_id, r["room_id"]) self.assertEqual(test_room_name, r["name"]) self.assertEqual(test_alias, r["canonical_alias"]) + self.assertEqual(RoomTypes.SPACE, r["room_type"]) def test_room_list_sort_order(self) -> None: """Test room list sort ordering. alphabetical name versus number of members, @@ -1579,8 +1586,8 @@ class RoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) - self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name")) + self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id")) + self.assertEqual("ж", channel.json_body["rooms"][0].get("name")) def test_single_room(self) -> None: """Test that a single room can be requested correctly""" @@ -1630,7 +1637,7 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("guest_access", channel.json_body) self.assertIn("history_visibility", channel.json_body) self.assertIn("state_events", channel.json_body) - + self.assertIn("room_type", channel.json_body) self.assertEqual(room_id_1, channel.json_body["room_id"]) def test_single_room_devices(self) -> None: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 0d44102237..12db68d564 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1379,7 +1379,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1434,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) @@ -1488,7 +1488,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): if channel.code != HTTPStatus.OK: raise HttpResponseException( - channel.code, channel.result["reason"], channel.json_body + channel.code, channel.result["reason"], channel.result["body"] ) # Set monthly active users to the limit @@ -1512,7 +1512,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123", "admin": False}, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1550,7 +1550,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # Admin user is not blocked by mau anymore - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertFalse(channel.json_body["admin"]) @@ -1585,7 +1585,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1626,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content=body, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) @@ -1636,6 +1636,41 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(len(pushers), 0) + @override_config( + { + "email": { + "enable_notifs": True, + "notif_for_new_users": True, + "notif_from": "test@example.com", + }, + "public_baseurl": "https://example.com", + } + ) + def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> None: + """ + Check that a new regular user is created successfully when they have a msisdn + threepid and email notif_for_new_users is set to True. + """ + url = self.url_prefix % "@bob:test" + + # Create user + body = { + "password": "abc123", + "threepids": [{"medium": "msisdn", "address": "1234567890"}], + } + + channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body, + ) + + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"]) + def test_set_password(self) -> None: """ Test setting a new password for another user. @@ -2372,7 +2407,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): content={"password": "abc123"}, ) - self.assertEqual(201, channel.code, msg=channel.json_body) + self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("bob", channel.json_body["displayname"]) diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index a43a137273..7ae926dc9c 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os import re from email.parser import Parser +from http import HTTPStatus from typing import Any, Dict, List, Optional, Union from unittest.mock import Mock @@ -95,10 +95,8 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): """ body = {"type": "m.login.password", "user": username, "password": password} - channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") - ) - self.assertEqual(channel.code, 403, channel.result) + channel = self.make_request("POST", "/_matrix/client/r0/login", body) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) def test_basic_password_reset(self) -> None: """Test basic password reset flow""" @@ -347,7 +345,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Now POST to the same endpoint, mimicking the same behaviour as clicking the # password reset confirm button @@ -362,7 +360,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, content_is_form=True, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" @@ -390,7 +388,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): new_password: str, session_id: str, client_secret: str, - expected_code: int = 200, + expected_code: int = HTTPStatus.OK, ) -> None: channel = self.make_request( "POST", @@ -479,16 +477,14 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertEqual(memberships[0].room_id, room_id, memberships) def deactivate(self, user_id: str, tok: str) -> None: - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "test", - }, - "erase": False, - } - ) + request_data = { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + } channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) @@ -715,7 +711,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -725,7 +723,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email(self) -> None: @@ -747,7 +745,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -756,7 +754,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_delete_email_if_disabled(self) -> None: @@ -781,7 +779,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -791,7 +791,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) @@ -817,7 +817,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -827,7 +829,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) def test_no_valid_token(self) -> None: @@ -852,7 +854,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -862,7 +866,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @override_config({"next_link_domain_whitelist": None}) @@ -872,7 +876,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="https://example.com/a/good/site", - expect_code=200, + expect_code=HTTPStatus.OK, ) @override_config({"next_link_domain_whitelist": None}) @@ -884,7 +888,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="some-protocol://abcdefghijklmopqrstuvwxyz", - expect_code=200, + expect_code=HTTPStatus.OK, ) @override_config({"next_link_domain_whitelist": None}) @@ -895,7 +899,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="file:///host/path", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) @@ -907,28 +911,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link=None, - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://example.com/some/good/page", - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://example.org/some/also/good/page", - expect_code=200, + expect_code=HTTPStatus.OK, ) self._request_token( "something@example.com", "some_secret", next_link="https://bad.example.org/some/bad/page", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) @override_config({"next_link_domain_whitelist": []}) @@ -940,7 +944,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): "something@example.com", "some_secret", next_link="https://example.com/a/page", - expect_code=400, + expect_code=HTTPStatus.BAD_REQUEST, ) def _request_token( @@ -948,8 +952,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): email: str, client_secret: str, next_link: Optional[str] = None, - expect_code: int = 200, - ) -> str: + expect_code: int = HTTPStatus.OK, + ) -> Optional[str]: """Request a validation token to add an email address to a user's account Args: @@ -959,7 +963,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): expect_code: Expected return code of the call Returns: - The ID of the new threepid validation session + The ID of the new threepid validation session, or None if the response + did not contain a session ID. """ body = {"client_secret": client_secret, "email": email, "send_attempt": 1} if next_link: @@ -992,7 +997,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) self.assertEqual(expected_errcode, channel.json_body["errcode"]) self.assertEqual(expected_error, channel.json_body["error"]) @@ -1001,7 +1008,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" @@ -1051,7 +1058,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -1060,7 +1067,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} @@ -1091,7 +1098,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): """Tests that not providing any MXID raises an error.""" self._test_status( users=None, - expected_status_code=400, + expected_status_code=HTTPStatus.BAD_REQUEST, expected_errcode=Codes.MISSING_PARAM, ) @@ -1099,7 +1106,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): """Tests that providing an invalid MXID raises an error.""" self._test_status( users=["bad:test"], - expected_status_code=400, + expected_status_code=HTTPStatus.BAD_REQUEST, expected_errcode=Codes.INVALID_PARAM, ) @@ -1285,7 +1292,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): def _test_status( self, users: Optional[List[str]], - expected_status_code: int = 200, + expected_status_code: int = HTTPStatus.OK, expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, expected_failures: Optional[List[str]] = None, expected_errcode: Optional[str] = None, diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py index aca03afd0e..7a88aa2cda 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor +from synapse.appservice import ApplicationService from synapse.rest import admin from synapse.rest.client import directory, login, room from synapse.server import HomeServer @@ -96,8 +96,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # We use deliberately a localpart under the length threshold so # that we can make sure that the check is done on the whole alias. - data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} - request_data = json.dumps(data) + request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) @@ -109,8 +108,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # Check with an alias of allowed length. There should already be # a test that ensures it works in test_register.py, but let's be # as cautious as possible here. - data = {"room_alias_name": random_string(5)} - request_data = json.dumps(data) + request_data = {"room_alias_name": random_string(5)} channel = self.make_request( "POST", url, request_data, access_token=self.user_tok ) @@ -129,6 +127,38 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + def test_deleting_alias_via_directory_appservice(self) -> None: + user_id = "@as:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + id="1234", + namespaces={"aliases": [{"regex": "#asns-*", "exclusive": True}]}, + sender=user_id, + ) + self.hs.get_datastores().main.services_cache.append(appservice) + + # Add an alias for the room, as the appservice + alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string() + request_data = {"room_id": self.room_id} + + channel = self.make_request( + "PUT", + f"/_matrix/client/r0/directory/room/{alias}", + request_data, + access_token=as_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + # Then try to remove the alias, as the appservice + channel = self.make_request( + "DELETE", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=as_token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + def test_deleting_nonexistant_alias(self) -> None: # Check that no alias exists alias = "#potato:test" @@ -159,8 +189,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.hs.hostname, ) - data = {"aliases": [self.random_alias(alias_length)]} - request_data = json.dumps(data) + request_data = {"aliases": [self.random_alias(alias_length)]} channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok @@ -172,8 +201,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ) -> str: alias = self.random_alias(alias_length) url = "/_matrix/client/r0/directory/room/%s" % alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.user_tok @@ -181,6 +209,19 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, expected_code, channel.result) return alias + def test_invalid_alias(self) -> None: + alias = "#potato" + channel = self.make_request( + "GET", + f"/_matrix/client/r0/directory/room/{alias}", + access_token=self.user_tok, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertIn("error", channel.json_body, channel.json_body) + self.assertEqual( + channel.json_body["errcode"], "M_INVALID_PARAM", channel.json_body + ) + def random_alias(self, length: int) -> str: return RoomAlias(random_string(length), self.hs.hostname).to_string() diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 299b9d21e2..dc17c9d113 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor @@ -51,12 +50,11 @@ class IdentityTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK, channel.result) room_id = channel.json_body["room_id"] - params = { + request_data = { "id_server": "testis", "medium": "email", "address": "test@example.com", } - request_data = json.dumps(params) request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii") channel = self.make_request( b"POST", request_url, request_data, access_token=tok diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index f4ea1209d9..a2958f6959 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import time import urllib.parse -from typing import Any, Dict, List, Optional, Union +from http import HTTPStatus +from typing import Any, Dict, List, Optional from unittest.mock import Mock from urllib.parse import urlencode @@ -41,7 +41,7 @@ from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless try: - import jwt + from authlib.jose import jwk, jwt HAS_JWT = True except ImportError: @@ -261,20 +261,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) access_token = channel.json_body["access_token"] device_id = channel.json_body["device_id"] # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) @@ -288,7 +288,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # more requests with the expired token should still return a soft-logout self.reactor.advance(3600) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) @@ -296,7 +296,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self._delete_device(access_token_2, "kermit", "monkey", device_id) channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], False) @@ -307,7 +307,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request( b"DELETE", "devices/" + device_id, access_token=access_token ) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) # check it's a UI-Auth fail self.assertEqual( set(channel.json_body.keys()), @@ -330,7 +330,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): access_token=access_token, content={"auth": auth}, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: @@ -341,14 +341,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) @@ -367,14 +367,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # we should now be able to make requests with the access token channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # time passes self.reactor.advance(24 * 3600) # ... and we should be soft-logouted channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["soft_logout"], True) @@ -399,7 +399,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", "/_matrix/client/v3/login", - json.dumps(body).encode("utf8"), + body, custom_headers=None, ) @@ -466,7 +466,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def test_get_login_flows(self) -> None: """GET /login should return password and SSO flows""" channel = self.make_request("GET", "/_matrix/client/r0/login") - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) expected_flow_types = [ "m.login.cas", @@ -494,14 +494,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker channel = self._make_sso_redirect_request(None) - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers uri = location_headers[0] # hitting that picker should give us some HTML channel = self.make_request("GET", uri) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) # parse the form to check it has fields assumed elsewhere in this class html = channel.result["body"].decode("utf-8") @@ -530,7 +530,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): + "&idp=cas", shorthand=False, ) - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers cas_uri = location_headers[0] @@ -555,7 +555,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=saml", ) - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers saml_uri = location_headers[0] @@ -579,7 +579,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=oidc", ) - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers oidc_uri = location_headers[0] @@ -606,7 +606,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) content_type_headers = channel.headers.getRawHeaders("Content-Type") assert content_type_headers self.assertTrue(content_type_headers[-1].startswith("text/html")) @@ -634,7 +634,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): "/login", content={"type": "m.login.token", "token": login_token}, ) - self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.code, HTTPStatus.OK, chan.result) self.assertEqual(chan.json_body["user_id"], "@user1:test") def test_multi_sso_redirect_to_unknown(self) -> None: @@ -643,18 +643,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) def test_client_idp_redirect_to_unknown(self) -> None: """If the client tries to pick an unknown IdP, return a 404""" channel = self._make_sso_redirect_request("xxx") - self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") def test_client_idp_redirect_to_oidc(self) -> None: """If the client pick a known IdP, redirect to it""" channel = self._make_sso_redirect_request("oidc") - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers oidc_uri = location_headers[0] @@ -765,7 +765,7 @@ class CASTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", cas_ticket_url) # Test that the response is HTML. - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) content_type_header_value = "" for header in channel.result.get("headers", []): if header[0] == b"Content-Type": @@ -841,7 +841,7 @@ class CASTestCase(unittest.HomeserverTestCase): self.assertIn(b"SSO account deactivated", channel.result["body"]) -@skip_unless(HAS_JWT, "requires jwt") +@skip_unless(HAS_JWT, "requires authlib") class JWTTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -866,11 +866,9 @@ class JWTTestCase(unittest.HomeserverTestCase): return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) - if isinstance(result, bytes): - return result.decode("ascii") - return result + header = {"alg": self.jwt_algorithm} + result: bytes = jwt.encode(header, payload, secret) + return result.decode("ascii") def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} @@ -902,7 +900,8 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Signature has expired" + channel.json_body["error"], + "JWT validation failed: expired_token: The token is expired", ) def test_login_jwt_not_before(self) -> None: @@ -912,7 +911,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - "JWT validation failed: The token is not yet valid (nbf)", + "JWT validation failed: invalid_token: The token is not valid yet", ) def test_login_no_sub(self) -> None: @@ -934,7 +933,8 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid issuer" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "iss"', ) # Not providing an issuer. @@ -943,7 +943,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - 'JWT validation failed: Token is missing the "iss" claim', + 'JWT validation failed: missing_claim: Missing "iss" claim', ) def test_login_iss_no_config(self) -> None: @@ -965,7 +965,8 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "aud"', ) # Not providing an audience. @@ -974,7 +975,7 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], - 'JWT validation failed: Token is missing the "aud" claim', + 'JWT validation failed: missing_claim: Missing "aud" claim', ) def test_login_aud_no_config(self) -> None: @@ -983,7 +984,8 @@ class JWTTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" + channel.json_body["error"], + 'JWT validation failed: invalid_claim: Invalid claim "aud"', ) def test_login_default_sub(self) -> None: @@ -1010,7 +1012,7 @@ class JWTTestCase(unittest.HomeserverTestCase): # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use # RSS256, with a public key configured in synapse as "jwt_secret", and tokens # signed by the private key. -@skip_unless(HAS_JWT, "requires jwt") +@skip_unless(HAS_JWT, "requires authlib") class JWTPubKeyTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, @@ -1071,11 +1073,11 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): return config def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. - result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") - if isinstance(result, bytes): - return result.decode("ascii") - return result + header = {"alg": "RS256"} + if secret.startswith("-----BEGIN RSA PRIVATE KEY-----"): + secret = jwk.dumps(secret, kty="RSA") + result: bytes = jwt.encode(header, payload, secret) + return result.decode("ascii") def jwt_login(self, *args: Any) -> FakeChannel: params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} @@ -1244,7 +1246,7 @@ class UsernamePickerTestCase(HomeserverTestCase): ) # that should redirect to the username picker - self.assertEqual(channel.code, 302, channel.result) + self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers picker_url = location_headers[0] @@ -1288,7 +1290,7 @@ class UsernamePickerTestCase(HomeserverTestCase): ("Content-Length", str(len(content))), ], ) - self.assertEqual(chan.code, 302, chan.result) + self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result) location_headers = chan.headers.getRawHeaders("Location") assert location_headers @@ -1298,7 +1300,7 @@ class UsernamePickerTestCase(HomeserverTestCase): path=location_headers[0], custom_headers=[("Cookie", "username_mapping_session=" + session_id)], ) - self.assertEqual(chan.code, 302, chan.result) + self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result) location_headers = chan.headers.getRawHeaders("Location") assert location_headers @@ -1323,5 +1325,5 @@ class UsernamePickerTestCase(HomeserverTestCase): "/login", content={"type": "m.login.token", "token": login_token}, ) - self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.code, HTTPStatus.OK, chan.result) self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py index 3a74d2e96c..e19d21d6ee 100644 --- a/tests/rest/client/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from http import HTTPStatus from twisted.test.proto_helpers import MemoryReactor @@ -89,7 +88,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_too_short(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "shorty"}) + request_data = {"username": "kermit", "password": "shorty"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -100,7 +99,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_digit(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + request_data = {"username": "kermit", "password": "longerpassword"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -111,7 +110,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_symbol(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + request_data = {"username": "kermit", "password": "l0ngerpassword"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -122,7 +121,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_uppercase(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + request_data = {"username": "kermit", "password": "l0ngerpassword!"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -133,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_no_lowercase(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + request_data = {"username": "kermit", "password": "L0NGERPASSWORD!"} channel = self.make_request("POST", self.register_url, request_data) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) @@ -144,7 +143,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): ) def test_password_compliant(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + request_data = {"username": "kermit", "password": "L0ngerpassword!"} channel = self.make_request("POST", self.register_url, request_data) # Getting a 401 here means the password has passed validation and the server has @@ -161,16 +160,14 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase): user_id = self.register_user("kermit", compliant_password) tok = self.login("kermit", compliant_password) - request_data = json.dumps( - { - "new_password": not_compliant_password, - "auth": { - "password": compliant_password, - "type": LoginType.PASSWORD, - "user": user_id, - }, - } - ) + request_data = { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } channel = self.make_request( "POST", "/_matrix/client/r0/account/password", diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 29bed0e872..8de5a342ae 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -153,18 +153,22 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def _get_displayname(self, name: Optional[str] = None) -> str: + def _get_displayname(self, name: Optional[str] = None) -> Optional[str]: channel = self.make_request( "GET", "/profile/%s/displayname" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) - return channel.json_body["displayname"] + # FIXME: If a user has no displayname set, Synapse returns 200 and omits a + # displayname from the response. This contradicts the spec, see #13137. + return channel.json_body.get("displayname") - def _get_avatar_url(self, name: Optional[str] = None) -> str: + def _get_avatar_url(self, name: Optional[str] = None) -> Optional[str]: channel = self.make_request( "GET", "/profile/%s/avatar_url" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) + # FIXME: If a user has no avatar set, Synapse returns 200 and omits an + # avatar_url from the response. This contradicts the spec, see #13137. return channel.json_body.get("avatar_url") @unittest.override_config({"max_avatar_size": 50}) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index afb08b2736..071b488cc0 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import datetime -import json import os from typing import Any, Dict, List, Tuple @@ -62,9 +61,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) self.hs.get_datastores().main.services_cache.append(appservice) - request_data = json.dumps( - {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) + request_data = { + "username": "as_user_kermit", + "type": APP_SERVICE_REGISTRATION_TYPE, + } channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data @@ -85,7 +85,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) self.hs.get_datastores().main.services_cache.append(appservice) - request_data = json.dumps({"username": "as_user_kermit"}) + request_data = {"username": "as_user_kermit"} channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data @@ -95,9 +95,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_appservice_registration_invalid(self) -> None: self.appservice = None # no application service exists - request_data = json.dumps( - {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) + request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) @@ -105,14 +103,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"401", channel.result) def test_POST_bad_password(self) -> None: - request_data = json.dumps({"username": "kermit", "password": 666}) + request_data = {"username": "kermit", "password": 666} channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.result["code"], b"400", channel.result) self.assertEqual(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self) -> None: - request_data = json.dumps({"username": 777, "password": "monkey"}) + request_data = {"username": 777, "password": "monkey"} channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.result["code"], b"400", channel.result) @@ -121,13 +119,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_user_valid(self) -> None: user_id = "@kermit:test" device_id = "frogfone" - params = { + request_data = { "username": "kermit", "password": "monkey", "device_id": device_id, "auth": {"type": LoginType.DUMMY}, } - request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) det_data = { @@ -140,7 +137,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @override_config({"enable_registration": False}) def test_POST_disabled_registration(self) -> None: - request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request_data = {"username": "kermit", "password": "monkey"} self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) channel = self.make_request(b"POST", self.url, request_data) @@ -188,13 +185,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self) -> None: for i in range(0, 6): - params = { + request_data = { "username": "kermit" + str(i), "password": "monkey", "device_id": "frogfone", "auth": {"type": LoginType.DUMMY}, } - request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) if i == 5: @@ -234,7 +230,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): } # Request without auth to get flows and session - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) flows = channel.json_body["flows"] # Synapse adds a dummy stage to differentiate flows where otherwise one @@ -251,8 +247,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -262,8 +257,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "type": LoginType.DUMMY, "session": session, } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) + channel = self.make_request(b"POST", self.url, params) det_data = { "user_id": f"@{username}:{self.hs.hostname}", "home_server": self.hs.hostname, @@ -290,7 +284,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "password": "monkey", } # Request without auth to get session - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Test with token param missing (invalid) @@ -298,21 +292,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "type": LoginType.REGISTRATION_TOKEN, "session": session, } - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with non-string (invalid) params["auth"]["token"] = 1234 - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with unknown token (invalid) params["auth"]["token"] = "1234" - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -337,9 +331,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): params1: JsonDict = {"username": "bert", "password": "monkey"} params2: JsonDict = {"username": "ernie", "password": "monkey"} # Do 2 requests without auth to get two session IDs - channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + channel1 = self.make_request(b"POST", self.url, params1) session1 = channel1.json_body["session"] - channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + channel2 = self.make_request(b"POST", self.url, params2) session2 = channel2.json_body["session"] # Use token with session1 and check `pending` is 1 @@ -348,9 +342,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session1, } - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Repeat request to make sure pending isn't increased again - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) pending = self.get_success( store.db_pool.simple_select_one_onecol( "registration_tokens", @@ -366,14 +360,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session2, } - channel = self.make_request(b"POST", self.url, json.dumps(params2)) + channel = self.make_request(b"POST", self.url, params2) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Check pending=0 and completed=1 res = self.get_success( store.db_pool.simple_select_one( @@ -386,7 +380,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(res["completed"], 1) # Check auth still fails when using token with session2 - channel = self.make_request(b"POST", self.url, json.dumps(params2)) + channel = self.make_request(b"POST", self.url, params2) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -411,7 +405,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) params: JsonDict = {"username": "kermit", "password": "monkey"} # Request without auth to get session - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Check authentication fails with expired token @@ -420,7 +414,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) self.assertEqual(channel.result["code"], b"401", channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -435,7 +429,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) # Check authentication succeeds - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -460,9 +454,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Do 2 requests without auth to get two session IDs params1: JsonDict = {"username": "bert", "password": "monkey"} params2: JsonDict = {"username": "ernie", "password": "monkey"} - channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + channel1 = self.make_request(b"POST", self.url, params1) session1 = channel1.json_body["session"] - channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + channel2 = self.make_request(b"POST", self.url, params2) session2 = channel2.json_body["session"] # Use token with both sessions @@ -471,18 +465,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session1, } - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) params2["auth"] = { "type": LoginType.REGISTRATION_TOKEN, "token": token, "session": session2, } - self.make_request(b"POST", self.url, json.dumps(params2)) + self.make_request(b"POST", self.url, params2) # Complete registration with session1 params1["auth"]["type"] = LoginType.DUMMY - self.make_request(b"POST", self.url, json.dumps(params1)) + self.make_request(b"POST", self.url, params1) # Check `result` of registration token stage for session1 is `True` result1 = self.get_success( @@ -550,7 +544,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): # Do request without auth to get a session ID params: JsonDict = {"username": "kermit", "password": "monkey"} - channel = self.make_request(b"POST", self.url, json.dumps(params)) + channel = self.make_request(b"POST", self.url, params) session = channel.json_body["session"] # Use token @@ -559,7 +553,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "token": token, "session": session, } - self.make_request(b"POST", self.url, json.dumps(params)) + self.make_request(b"POST", self.url, params) # Delete token self.get_success( @@ -592,9 +586,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "require_at_registration": True, }, "account_threepid_delegates": { - "email": "https://id_server", "msisdn": "https://id_server", }, + "email": {"notif_from": "Synapse <synapse@example.com>"}, } ) def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: @@ -827,8 +821,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = {"user_id": user_id} - request_data = json.dumps(params) + request_data = {"user_id": user_id} channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEqual(channel.result["code"], b"200", channel.result) @@ -845,12 +838,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = { + request_data = { "user_id": user_id, "expiration_ts": 0, "enable_renewal_emails": False, } - request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEqual(channel.result["code"], b"200", channel.result) @@ -870,12 +862,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_synapse/admin/v1/account_validity/validity" - params = { + request_data = { "user_id": user_id, "expiration_ts": 0, "enable_renewal_emails": False, } - request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) self.assertEqual(channel.result["code"], b"200", channel.result) @@ -1041,16 +1032,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): (user_id, tok) = self.create_user() - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "monkey", - }, - "erase": False, - } - ) + request_data = { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + } channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index aa84906548..ad03eee17b 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -800,7 +800,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): ) expected_event_ids.append(channel.json_body["event_id"]) - prev_token = "" + prev_token: Optional[str] = "" found_event_ids: List[str] = [] for _ in range(20): from_token = "" diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py index 20a259fc43..ad0d0209f7 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json - from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -77,10 +75,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase): def _assert_status(self, response_status: int, data: JsonDict) -> None: channel = self.make_request( - "POST", - self.report_path, - json.dumps(data), - access_token=self.other_user_tok, + "POST", self.report_path, data, access_token=self.other_user_tok ) self.assertEqual( response_status, int(channel.result["code"]), msg=channel.result["body"] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 35c59ee9e0..c45cb32090 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,11 +18,12 @@ """Tests REST events for /rooms paths.""" import json -from typing import Any, Dict, Iterable, List, Optional, Union +from http import HTTPStatus +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from unittest.mock import Mock, call from urllib import parse as urlparse -# `Literal` appears with Python 3.8. +from parameterized import param, parameterized from typing_extensions import Literal from twisted.test.proto_helpers import MemoryReactor @@ -33,7 +34,9 @@ from synapse.api.constants import ( EventContentFields, EventTypes, Membership, + PublicRoomsFilterFields, RelationTypes, + RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus @@ -102,7 +105,7 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # set topic for public room channel = self.make_request( @@ -110,7 +113,7 @@ class RoomPermissionsTestCase(RoomBase): ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), b'{"topic":"Public Room Topic"}', ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # auth as user_id now self.helper.auth_user_id = self.user_id @@ -132,28 +135,28 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), msg_content, ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room not joined (no state), expect 403 channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # send message in created room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # send message in created room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_topic_perms(self) -> None: topic_content = b'{"topic":"My Topic Name"}' @@ -163,28 +166,28 @@ class RoomPermissionsTestCase(RoomBase): channel = self.make_request( "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room not joined, expect 403 channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set topic in created PRIVATE room and invited, expect 403 self.helper.invite( room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id ) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # get topic in created PRIVATE room and invited, expect 403 channel = self.make_request("GET", topic_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set/get topic in created PRIVATE room and joined, expect 200 self.helper.join(room=self.created_rmid, user=self.user_id) @@ -192,25 +195,25 @@ class RoomPermissionsTestCase(RoomBase): # Only room ops can set topic by default self.helper.auth_user_id = self.rmcreator_id channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.helper.auth_user_id = self.user_id channel = self.make_request("GET", topic_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) # set/get topic in created PRIVATE room and left, expect 403 self.helper.leave(room=self.created_rmid, user=self.user_id) channel = self.make_request("PUT", topic_path, topic_content) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", topic_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # get topic in PUBLIC room, not joined, expect 403 channel = self.make_request( "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) # set topic in PUBLIC room, not joined, expect 403 channel = self.make_request( @@ -218,7 +221,7 @@ class RoomPermissionsTestCase(RoomBase): "/rooms/%s/state/m.room.topic" % self.created_public_rmid, topic_content, ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def _test_get_membership( self, room: str, members: Iterable = frozenset(), expect_code: int = 200 @@ -307,14 +310,14 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=self.rmcreator_id, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) self.helper.change_membership( room=room, src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) def test_joined_permissions(self) -> None: @@ -340,7 +343,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # set left of other, expect 403 @@ -349,7 +352,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # set left of self, expect 200 @@ -369,7 +372,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=usr, membership=Membership.INVITE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) self.helper.change_membership( @@ -377,7 +380,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=usr, membership=Membership.JOIN, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # It is always valid to LEAVE if you've already left (currently.) @@ -386,7 +389,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=self.rmcreator_id, membership=Membership.LEAVE, - expect_code=403, + expect_code=HTTPStatus.FORBIDDEN, ) # tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember @@ -403,7 +406,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.BAN, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.FORBIDDEN, ) @@ -413,7 +416,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.BAN, - expect_code=200, + expect_code=HTTPStatus.OK, ) # from ban to invite: Must never happen. @@ -422,7 +425,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.INVITE, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -432,7 +435,7 @@ class RoomPermissionsTestCase(RoomBase): src=other, targ=other, membership=Membership.JOIN, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -442,7 +445,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.BAN, - expect_code=200, + expect_code=HTTPStatus.OK, ) # from ban to knock: Must never happen. @@ -451,7 +454,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.KNOCK, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.BAD_STATE, ) @@ -461,7 +464,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.user_id, targ=other, membership=Membership.LEAVE, - expect_code=403, # expect failure + expect_code=HTTPStatus.FORBIDDEN, # expect failure expect_errcode=Codes.FORBIDDEN, ) @@ -471,7 +474,7 @@ class RoomPermissionsTestCase(RoomBase): src=self.rmcreator_id, targ=other, membership=Membership.LEAVE, - expect_code=200, + expect_code=HTTPStatus.OK, ) @@ -491,7 +494,7 @@ class RoomStateTestCase(RoomBase): "/rooms/%s/state" % room_id, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertCountEqual( [state_event["type"] for state_event in channel.json_body], { @@ -514,7 +517,7 @@ class RoomStateTestCase(RoomBase): "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id), ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(channel.json_body, {"membership": "join"}) @@ -528,16 +531,16 @@ class RoomsMemberListTestCase(RoomBase): def test_get_member_list(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_get_member_list_no_room(self) -> None: channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission(self) -> None: room_id = self.helper.create_room_as("@some_other_guy:red") channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_with_at_token(self) -> None: """ @@ -548,7 +551,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that permission is denied for @sid1:red to get the @@ -557,7 +560,7 @@ class RoomsMemberListTestCase(RoomBase): "GET", f"/rooms/{room_id}/members?at={sync_token}", ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member(self) -> None: """ @@ -570,14 +573,14 @@ class RoomsMemberListTestCase(RoomBase): # check that the user can see the member list to start with channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") # check the user can no longer see the member list channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_no_permission_former_member_with_at_token(self) -> None: """ @@ -591,14 +594,14 @@ class RoomsMemberListTestCase(RoomBase): # sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check that the user can see the member list to start with channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # ban the user (Note: the user is actually allowed to see this event and # state so that they know they're banned!) @@ -610,14 +613,14 @@ class RoomsMemberListTestCase(RoomBase): # now, with the original user, sync again to get a new at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] # check the user can no longer see the updated member list channel = self.make_request( "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) ) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) def test_get_member_list_mixed_memberships(self) -> None: room_creator = "@some_other_guy:red" @@ -626,17 +629,17 @@ class RoomsMemberListTestCase(RoomBase): self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) # can't see list if you're just invited. channel = self.make_request("GET", room_path) - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.helper.join(room=room_id, user=self.user_id) # can see list now joined channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.helper.leave(room=room_id, user=self.user_id) # can see old list once left channel = self.make_request("GET", room_path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_get_member_list_cancellation(self) -> None: """Test cancellation of a `/rooms/$room_id/members` request.""" @@ -649,7 +652,7 @@ class RoomsMemberListTestCase(RoomBase): "/rooms/%s/members" % room_id, ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertLessEqual( { @@ -669,7 +672,7 @@ class RoomsMemberListTestCase(RoomBase): # first sync to get an at token channel = self.make_request("GET", "/sync") - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) sync_token = channel.json_body["next_batch"] channel = make_request_with_cancellation_test( @@ -680,7 +683,7 @@ class RoomsMemberListTestCase(RoomBase): "/rooms/%s/members?at=%s" % (room_id, sync_token), ) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(len(channel.json_body["chunk"]), 1) self.assertLessEqual( { @@ -704,19 +707,34 @@ class RoomsCreateTestCase(RoomBase): # POST with no config keys, expect new room id channel = self.make_request("POST", "/createRoom", "{}") - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(44, channel.resource_usage.db_txn_count) + + def test_post_room_initial_state(self) -> None: + # POST with initial_state config key, expect new room id + channel = self.make_request( + "POST", + "/createRoom", + b'{"initial_state":[{"type": "m.bridge", "content": {}}]}', + ) + + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + self.assertTrue("room_id" in channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(50, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_custom_key(self) -> None: # POST with custom config keys, expect new room id channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_known_and_unknown_keys(self) -> None: @@ -724,16 +742,16 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("room_id" in channel.json_body) def test_post_room_invalid_content(self) -> None: # POST with invalid content / paths, expect 400 channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) def test_post_room_invitees_invalid_mxid(self) -> None: # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 @@ -741,7 +759,7 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request( "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' ) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) def test_post_room_invitees_ratelimit(self) -> None: @@ -752,20 +770,18 @@ class RoomsCreateTestCase(RoomBase): # Build the request's content. We use local MXIDs because invites over federation # are more difficult to mock. - content = json.dumps( - { - "invite": [ - "@alice1:red", - "@alice2:red", - "@alice3:red", - "@alice4:red", - ] - } - ).encode("utf8") + content = { + "invite": [ + "@alice1:red", + "@alice2:red", + "@alice3:red", + "@alice4:red", + ] + } # Test that the invites are correctly ratelimited. channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(400, channel.code) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code) self.assertEqual( "Cannot invite so many users at once", channel.json_body["error"], @@ -778,7 +794,7 @@ class RoomsCreateTestCase(RoomBase): # Test that the invites aren't ratelimited anymore. channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly bypassed @@ -802,7 +818,7 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) @@ -813,14 +829,14 @@ class RoomsCreateTestCase(RoomBase): In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. """ - async def user_may_join_room( + async def user_may_join_room_codes( mxid: str, room_id: str, is_invite: bool, ) -> Codes: return Codes.CONSENT_NOT_GIVEN - join_mock = Mock(side_effect=user_may_join_room) + join_mock = Mock(side_effect=user_may_join_room_codes) self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock) channel = self.make_request( @@ -828,10 +844,29 @@ class RoomsCreateTestCase(RoomBase): "/createRoom", {}, ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) self.assertEqual(join_mock.call_count, 0) + # Now change the return value of the callback to deny any join. Since we're + # creating the room, despite the return value, we should be able to join. + async def user_may_join_room_tuple( + mxid: str, + room_id: str, + is_invite: bool, + ) -> Tuple[Codes, dict]: + return Codes.INCOMPATIBLE_ROOM_VERSION, {} + + join_mock.side_effect = user_may_join_room_tuple + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + self.assertEqual(join_mock.call_count, 0) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -846,54 +881,68 @@ class RoomTopicTestCase(RoomBase): def test_invalid_puts(self) -> None: # missing keys or invalid json channel = self.make_request("PUT", self.path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request( "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' ) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", self.path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid key, wrong type content = '{"topic":["Topic name"]}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_topic(self) -> None: # nothing should be there channel = self.make_request("GET", self.path) - self.assertEqual(404, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.result["body"]) # valid put content = '{"topic":"Topic name"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) def test_rooms_topic_with_extra_keys(self) -> None: # valid put with extra keys content = '{"topic":"Seasons","subtopic":"Summer"}' channel = self.make_request("PUT", self.path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # valid get channel = self.make_request("GET", self.path) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assert_dict(json.loads(content), channel.json_body) @@ -909,22 +958,34 @@ class RoomMemberStateTestCase(RoomBase): path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) # missing keys or invalid json channel = self.make_request("PUT", path, "{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, '{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, "") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # valid keys, wrong types content = '{"membership":["%s","%s","%s"]}' % ( @@ -933,7 +994,9 @@ class RoomMemberStateTestCase(RoomBase): Membership.LEAVE, ) channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_members_self(self) -> None: path = "/rooms/%s/state/m.room.member/%s" % ( @@ -944,10 +1007,10 @@ class RoomMemberStateTestCase(RoomBase): # valid join message (NOOP since we made the room) content = '{"membership":"%s"}' % Membership.JOIN channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) expected_response = {"membership": Membership.JOIN} self.assertEqual(expected_response, channel.json_body) @@ -962,10 +1025,10 @@ class RoomMemberStateTestCase(RoomBase): # valid invite message content = '{"membership":"%s"}' % Membership.INVITE channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) def test_rooms_members_other_custom_keys(self) -> None: @@ -981,10 +1044,10 @@ class RoomMemberStateTestCase(RoomBase): "Join us!", ) channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) channel = self.make_request("GET", path, content=b"") - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(json.loads(content), channel.json_body) @@ -1101,7 +1164,9 @@ class RoomJoinTestCase(RoomBase): # Now make the callback deny all room joins, and check that a join actually fails. return_value = False - self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + self.helper.join( + self.room3, self.user2, expect_code=HTTPStatus.FORBIDDEN, tok=self.tok2 + ) def test_spam_checker_may_join_room(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called @@ -1111,13 +1176,15 @@ class RoomJoinTestCase(RoomBase): """ # Register a dummy callback. Make it allow all room joins for now. - return_value: Union[Literal["NOT_SPAM"], Codes] = synapse.module_api.NOT_SPAM + return_value: Union[ + Literal["NOT_SPAM"], Tuple[Codes, dict], Codes + ] = synapse.module_api.NOT_SPAM async def user_may_join_room( userid: str, room_id: str, is_invited: bool, - ) -> Union[Literal["NOT_SPAM"], Codes]: + ) -> Union[Literal["NOT_SPAM"], Tuple[Codes, dict], Codes]: return return_value # `spec` argument is needed for this function mock to have `__qualname__`, which @@ -1161,8 +1228,28 @@ class RoomJoinTestCase(RoomBase): ) # Now make the callback deny all room joins, and check that a join actually fails. + # We pick an arbitrary Codes rather than the default `Codes.FORBIDDEN`. return_value = Codes.CONSENT_NOT_GIVEN - self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + self.helper.invite(self.room3, self.user1, self.user2, tok=self.tok1) + self.helper.join( + self.room3, + self.user2, + expect_code=HTTPStatus.FORBIDDEN, + expect_errcode=return_value, + tok=self.tok2, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + # As above, with the experimental extension that lets us return dictionaries. + return_value = (Codes.BAD_ALIAS, {"another_field": "12345"}) + self.helper.join( + self.room3, + self.user2, + expect_code=HTTPStatus.FORBIDDEN, + expect_errcode=return_value[0], + tok=self.tok2, + expect_additional_fields=return_value[1], + ) class RoomJoinRatelimitTestCase(RoomBase): @@ -1212,7 +1299,7 @@ class RoomJoinRatelimitTestCase(RoomBase): # Update the display name for the user. path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) # Check that all the rooms have been sent a profile update into. for room_id in room_ids: @@ -1277,40 +1364,153 @@ class RoomMessagesTestCase(RoomBase): path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) # missing keys or invalid json channel = self.make_request("PUT", path, b"{}") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'{"nao') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"text only") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) channel = self.make_request("PUT", path, b"") - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) def test_rooms_messages_sent(self) -> None: path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) content = b'{"body":"test","msgtype":{"type":"a"}}' channel = self.make_request("PUT", path, content) - self.assertEqual(400, channel.code, msg=channel.result["body"]) + self.assertEqual( + HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"] + ) # custom message types content = b'{"body":"test","msgtype":"test.custom.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) content = b'{"body":"test2","msgtype":"m.text"}' channel = self.make_request("PUT", path, content) - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) + + @parameterized.expand( + [ + # Allow + param( + name="NOT_SPAM", + value="NOT_SPAM", + expected_code=HTTPStatus.OK, + expected_fields={}, + ), + param( + name="False", + value=False, + expected_code=HTTPStatus.OK, + expected_fields={}, + ), + # Block + param( + name="scalene string", + value="ANY OTHER STRING", + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_FORBIDDEN"}, + ), + param( + name="True", + value=True, + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_FORBIDDEN"}, + ), + param( + name="Code", + value=Codes.LIMIT_EXCEEDED, + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={"errcode": "M_LIMIT_EXCEEDED"}, + ), + param( + name="Tuple", + value=(Codes.SERVER_NOT_TRUSTED, {"additional_field": "12345"}), + expected_code=HTTPStatus.FORBIDDEN, + expected_fields={ + "errcode": "M_SERVER_NOT_TRUSTED", + "additional_field": "12345", + }, + ), + ] + ) + def test_spam_checker_check_event_for_spam( + self, + name: str, + value: Union[str, bool, Codes, Tuple[Codes, JsonDict]], + expected_code: int, + expected_fields: dict, + ) -> None: + class SpamCheck: + mock_return_value: Union[ + str, bool, Codes, Tuple[Codes, JsonDict], bool + ] = "NOT_SPAM" + mock_content: Optional[JsonDict] = None + + async def check_event_for_spam( + self, + event: synapse.events.EventBase, + ) -> Union[str, Codes, Tuple[Codes, JsonDict], bool]: + self.mock_content = event.content + return self.mock_return_value + + spam_checker = SpamCheck() + + self.hs.get_spam_checker()._check_event_for_spam_callbacks.append( + spam_checker.check_event_for_spam + ) + + # Inject `value` as mock_return_value + spam_checker.mock_return_value = value + path = "/rooms/%s/send/m.room.message/check_event_for_spam_%s" % ( + urlparse.quote(self.room_id), + urlparse.quote(name), + ) + body = "test-%s" % name + content = '{"body":"%s","msgtype":"m.text"}' % body + channel = self.make_request("PUT", path, content) + + # Check that the callback has witnessed the correct event. + self.assertIsNotNone(spam_checker.mock_content) + if ( + spam_checker.mock_content is not None + ): # Checked just above, but mypy doesn't know about that. + self.assertEqual( + spam_checker.mock_content["body"], body, spam_checker.mock_content + ) + + # Check that we have the correct result. + self.assertEqual(expected_code, channel.code, msg=channel.result["body"]) + for expected_key, expected_value in expected_fields.items(): + self.assertEqual( + channel.json_body.get(expected_key, None), + expected_value, + "Field %s absent or invalid " % expected_key, + ) class RoomPowerLevelOverridesTestCase(RoomBase): @@ -1435,7 +1635,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) def test_normal_user_can_not_post_state_event(self) -> None: # Given I am a normal member of a room @@ -1449,7 +1649,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed because state events require PL>=50 - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " "user_level (0) < send_level (50)", @@ -1476,7 +1676,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am allowed - self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1504,7 +1704,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) @unittest.override_config( { @@ -1532,7 +1732,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel = self.make_request("PUT", path, "{}") # Then I am not allowed - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " + "user_level (0) < send_level (1)", @@ -1563,7 +1763,7 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): # Then I am not allowed because the public_chat config does not # affect this room, because this room is a private_chat - self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) self.assertEqual( "You don't have permission to post that to the room. " + "user_level (0) < send_level (50)", @@ -1582,7 +1782,7 @@ class RoomInitialSyncTestCase(RoomBase): def test_initial_sync(self) -> None: channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertEqual(self.room_id, channel.json_body["room_id"]) self.assertEqual("join", channel.json_body["membership"]) @@ -1625,7 +1825,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1636,7 +1836,7 @@ class RoomMessageListTestCase(RoomBase): channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) - self.assertEqual(200, channel.code) + self.assertEqual(HTTPStatus.OK, channel.code) self.assertTrue("start" in channel.json_body) self.assertEqual(token, channel.json_body["start"]) self.assertTrue("chunk" in channel.json_body) @@ -1675,7 +1875,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) @@ -1703,7 +1903,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) @@ -1720,7 +1920,7 @@ class RoomMessageListTestCase(RoomBase): json.dumps({"types": [EventTypes.Message]}), ), ) - self.assertEqual(channel.code, 200, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) chunk = channel.json_body["chunk"] self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) @@ -1848,14 +2048,98 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): def test_restricted_no_auth(self) -> None: channel = self.make_request("GET", self.url) - self.assertEqual(channel.code, 401, channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result) def test_restricted_auth(self) -> None: self.register_user("user", "pass") tok = self.login("user", "pass") channel = self.make_request("GET", self.url, access_token=tok) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) + + +class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + + config = self.default_config() + config["allow_public_rooms_without_auth"] = True + config["experimental_features"] = {"msc3827_enabled": True} + self.hs = self.setup_test_homeserver(config=config) + self.url = b"/_matrix/client/r0/publicRooms" + + return self.hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + user = self.register_user("alice", "pass") + self.token = self.login(user, "pass") + + # Create a room + self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=self.token, + ) + # Create a space + self.helper.create_room_as( + user, + is_public=True, + extra_content={ + "visibility": "public", + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}, + }, + tok=self.token, + ) + + def make_public_rooms_request( + self, room_types: Union[List[Union[str, None]], None] + ) -> Tuple[List[Dict[str, Any]], int]: + channel = self.make_request( + "POST", + self.url, + {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}, + self.token, + ) + chunk = channel.json_body["chunk"] + count = channel.json_body["total_room_count_estimate"] + + self.assertEqual(len(chunk), count) + + return chunk, count + + def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None: + chunk, count = self.make_public_rooms_request(None) + + self.assertEqual(count, 2) + + def test_returns_only_rooms_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request([None]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), None) + + def test_returns_only_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space"]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), "m.space") + + def test_returns_both_rooms_and_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space", None]) + + self.assertEqual(count, 2) + + def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None: + chunk, count = self.make_public_rooms_request([]) + + self.assertEqual(count, 2) class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): @@ -1882,7 +2166,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): "Simple test for searching rooms over federation" self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", @@ -1890,7 +2174,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined] "testserv", @@ -1907,11 +2191,11 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): # The `get_public_rooms` should be called again if the first call fails # with a 404, when using search filters. self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] - HttpResponseException(404, "Not Found", b""), + HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""), make_awaitable({}), ) - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", @@ -1919,7 +2203,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): content={"filter": search_filter}, access_token=self.token, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined] [ @@ -1965,21 +2249,19 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): # Set a profile for the test user self.displayname = "test user" - data = {"displayname": self.displayname} - request_data = json.dumps(data) + request_data = {"displayname": self.displayname} channel = self.make_request( "PUT", "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) def test_per_room_profile_forbidden(self) -> None: - data = {"membership": "join", "displayname": "other test user"} - request_data = json.dumps(data) + request_data = {"membership": "join", "displayname": "other test user"} channel = self.make_request( "PUT", "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" @@ -1987,7 +2269,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): request_data, access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1995,7 +2277,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) @@ -2029,7 +2311,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2043,7 +2325,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2057,7 +2339,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2071,7 +2353,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2083,7 +2365,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2095,7 +2377,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason, "user_id": self.second_user_id}, access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2114,7 +2396,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): content={"reason": reason}, access_token=self.second_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self._check_for_reason(reason) @@ -2126,7 +2408,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): ), access_token=self.creator_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) event_content = channel.json_body @@ -2174,7 +2456,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2204,7 +2486,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2239,7 +2521,7 @@ class LabelsTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2319,16 +2601,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_labels(self) -> None: """Test that we can filter by a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2356,16 +2636,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): def test_search_filter_not_labels(self) -> None: """Test that we can filter by the absence of a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2405,16 +2683,14 @@ class LabelsTestCase(unittest.HomeserverTestCase): """Test that we can filter by both a label and the absence of another label on a /search request. """ - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS_NOT_LABELS, - } + request_data = { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, } } - ) + } self._send_labelled_messages_in_room() @@ -2587,7 +2863,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) return channel.json_body["chunk"] @@ -2692,7 +2968,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=self.tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2758,7 +3034,7 @@ class ContextTestCase(unittest.HomeserverTestCase): % (self.room_id, event_id), access_token=invited_tok, ) - self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, channel.result) events_before = channel.json_body["events_before"] @@ -2859,8 +3135,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -2889,8 +3164,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None: url = "/_matrix/client/r0/directory/room/" + alias - data = {"room_id": self.room_id} - request_data = json.dumps(data) + request_data = {"room_id": self.room_id} channel = self.make_request( "PUT", url, request_data, access_token=self.room_owner_tok @@ -2916,7 +3190,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", "rooms/%s/state/m.room.canonical_alias" % (self.room_id,), - json.dumps(content), + content, access_token=self.room_owner_tok, ) self.assertEqual(channel.code, expected_code, channel.result) @@ -3050,7 +3324,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable(0)) + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock self.hs.get_identity_handler().lookup_3pid = Mock( return_value=make_awaitable(None), @@ -3111,7 +3385,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable(0)) + make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock self.hs.get_identity_handler().lookup_3pid = Mock( return_value=make_awaitable(None), @@ -3149,7 +3423,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): make_invite_mock.assert_called_once() # Now change the return value of the callback to deny any invite and test that - # we can't send the invite. + # we can't send the invite. We pick an arbitrary error code to be able to check + # that the same code has been returned mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN) channel = self.make_request( method="POST", @@ -3163,6 +3438,27 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): access_token=self.tok, ) self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.CONSENT_NOT_GIVEN) + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once() + + # Run variant with `Tuple[Codes, dict]`. + mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"})) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT) + self.assertEqual(channel.json_body["field"], "value") # Also check that it stopped before calling _make_and_store_3pid_invite. make_invite_mock.assert_called_once() diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index e3efd1f1b0..b085c50356 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -606,11 +606,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(1) # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8") channel = self.make_request( "POST", f"/rooms/{self.room_id}/read_markers", - body, + {ReceiptTypes.READ: res["event_id"]}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 5eb0f243f7..9a48e9286f 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -21,7 +21,6 @@ from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError from synapse.api.room_versions import RoomVersion from synapse.events import EventBase -from synapse.events.snapshot import EventContext from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.rest import admin from synapse.rest.client import account, login, profile, room @@ -113,14 +112,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Have this homeserver skip event auth checks. This is necessary due to # event auth checks ensuring that events were signed by the sender's homeserver. - async def _check_event_auth( - origin: str, - event: EventBase, - context: EventContext, - *args: Any, - **kwargs: Any, - ) -> EventContext: - return context + async def _check_event_auth(origin: Any, event: Any, context: Any) -> None: + pass hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index 98c1039d33..5e7bf97482 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -48,10 +48,14 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.helper.join(self.room_id, self.other, tok=self.other_token) def _upgrade_room( - self, token: Optional[str] = None, room_id: Optional[str] = None + self, + token: Optional[str] = None, + room_id: Optional[str] = None, + expire_cache: bool = True, ) -> FakeChannel: - # We never want a cached response. - self.reactor.advance(5 * 60 + 1) + if expire_cache: + # We don't want a cached response. + self.reactor.advance(5 * 60 + 1) if room_id is None: room_id = self.room_id @@ -72,9 +76,24 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, channel.result) self.assertIn("replacement_room", channel.json_body) - def test_not_in_room(self) -> None: + new_room_id = channel.json_body["replacement_room"] + + # Check that the tombstone event points to the new room. + tombstone_event = self.get_success( + self.hs.get_storage_controllers().state.get_current_state_event( + self.room_id, EventTypes.Tombstone, "" + ) + ) + self.assertIsNotNone(tombstone_event) + self.assertEqual(new_room_id, tombstone_event.content["replacement_room"]) + + # Check that the new room exists. + room = self.get_success(self.store.get_room(new_room_id)) + self.assertIsNotNone(room) + + def test_never_in_room(self) -> None: """ - Upgrading a room should work fine. + A user who has never been in the room cannot upgrade the room. """ # The user isn't in the room. roomless = self.register_user("roomless", "pass") @@ -83,6 +102,16 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): channel = self._upgrade_room(roomless_token) self.assertEqual(403, channel.code, channel.result) + def test_left_room(self) -> None: + """ + A user who is no longer in the room cannot upgrade the room. + """ + # Remove the user from the room. + self.helper.leave(self.room_id, self.creator, tok=self.creator_token) + + channel = self._upgrade_room(self.creator_token) + self.assertEqual(403, channel.code, channel.result) + def test_power_levels(self) -> None: """ Another user can upgrade the room if their power level is increased. @@ -297,3 +326,47 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.assertEqual( create_event.content.get(EventContentFields.ROOM_TYPE), test_room_type ) + + def test_second_upgrade_from_same_user(self) -> None: + """A second room upgrade from the same user is deduplicated.""" + channel1 = self._upgrade_room() + self.assertEqual(200, channel1.code, channel1.result) + + channel2 = self._upgrade_room(expire_cache=False) + self.assertEqual(200, channel2.code, channel2.result) + + self.assertEqual( + channel1.json_body["replacement_room"], + channel2.json_body["replacement_room"], + ) + + def test_second_upgrade_after_delay(self) -> None: + """A second room upgrade is not deduplicated after some time has passed.""" + channel1 = self._upgrade_room() + self.assertEqual(200, channel1.code, channel1.result) + + channel2 = self._upgrade_room(expire_cache=True) + self.assertEqual(200, channel2.code, channel2.result) + + self.assertNotEqual( + channel1.json_body["replacement_room"], + channel2.json_body["replacement_room"], + ) + + def test_second_upgrade_from_different_user(self) -> None: + """A second room upgrade from a different user is blocked.""" + channel = self._upgrade_room() + self.assertEqual(200, channel.code, channel.result) + + channel = self._upgrade_room(self.other_token, expire_cache=False) + self.assertEqual(400, channel.code, channel.result) + + def test_first_upgrade_does_not_block_second(self) -> None: + """A second room upgrade is not blocked when a previous upgrade attempt was not + allowed. + """ + channel = self._upgrade_room(self.other_token) + self.assertEqual(403, channel.code, channel.result) + + channel = self._upgrade_room(expire_cache=False) + self.assertEqual(200, channel.code, channel.result) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index a0788b1bb0..105d418698 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -41,6 +41,7 @@ from twisted.web.resource import Resource from twisted.web.server import Site from synapse.api.constants import Membership +from synapse.api.errors import Codes from synapse.server import HomeServer from synapse.types import JsonDict @@ -135,7 +136,7 @@ class RestHelper: self.site, "POST", path, - json.dumps(content).encode("utf8"), + content, custom_headers=custom_headers, ) @@ -171,6 +172,8 @@ class RestHelper: expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, + expect_errcode: Optional[Codes] = None, + expect_additional_fields: Optional[dict] = None, ) -> None: self.change_membership( room=room, @@ -180,6 +183,8 @@ class RestHelper: appservice_user_id=appservice_user_id, membership=Membership.JOIN, expect_code=expect_code, + expect_errcode=expect_errcode, + expect_additional_fields=expect_additional_fields, ) def knock( @@ -205,7 +210,7 @@ class RestHelper: self.site, "POST", path, - json.dumps(data).encode("utf8"), + data, ) assert ( @@ -263,6 +268,7 @@ class RestHelper: appservice_user_id: Optional[str] = None, expect_code: int = HTTPStatus.OK, expect_errcode: Optional[str] = None, + expect_additional_fields: Optional[dict] = None, ) -> None: """ Send a membership state event into a room. @@ -303,7 +309,7 @@ class RestHelper: self.site, "PUT", path, - json.dumps(data).encode("utf8"), + data, ) assert ( @@ -323,6 +329,21 @@ class RestHelper: channel.result["body"], ) + if expect_additional_fields is not None: + for expect_key, expect_value in expect_additional_fields.items(): + assert expect_key in channel.json_body, "Expected field %s, got %s" % ( + expect_key, + channel.json_body, + ) + assert ( + channel.json_body[expect_key] == expect_value + ), "Expected: %s at %s, got: %s, resp: %s" % ( + expect_value, + expect_key, + channel.json_body[expect_key], + channel.json_body, + ) + self.auth_user_id = temp_id def send( @@ -371,7 +392,7 @@ class RestHelper: self.site, "PUT", path, - json.dumps(content or {}).encode("utf8"), + content or {}, custom_headers=custom_headers, ) diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index ea9e5889bf..1062081a06 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -370,6 +370,64 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) + def test_twitter_tag(self) -> None: + """Twitter card tags should be used if nothing else is available.""" + html = b""" + <html> + <meta name="twitter:card" content="summary"> + <meta name="twitter:description" content="Description"> + <meta name="twitter:site" content="@matrixdotorg"> + </html> + """ + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + self.assertEqual( + og, + { + "og:title": None, + "og:description": "Description", + "og:site_name": "@matrixdotorg", + }, + ) + + # But they shouldn't override Open Graph values. + html = b""" + <html> + <meta name="twitter:card" content="summary"> + <meta name="twitter:description" content="Description"> + <meta property="og:description" content="Real Description"> + <meta name="twitter:site" content="@matrixdotorg"> + <meta property="og:site_name" content="matrix.org"> + </html> + """ + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + self.assertEqual( + og, + { + "og:title": None, + "og:description": "Real Description", + "og:site_name": "matrix.org", + }, + ) + + def test_nested_nodes(self) -> None: + """A body with some nested nodes. Tests that we iterate over children + in the right order (and don't reverse the order of the text).""" + html = b""" + <a href="somewhere">Welcome <b>the bold <u>and underlined text <svg> + with a cheeky SVG</svg></u> and <strong>some</strong> tail text</b></a> + """ + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + self.assertEqual( + og, + { + "og:title": None, + "og:description": "Welcome\n\nthe bold\n\nand underlined text\n\nand\n\nsome\n\ntail text", + }, + ) + class MediaEncodingTestCase(unittest.TestCase): def test_meta_charset(self) -> None: diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 7204b2dfe0..d18fc13c21 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -23,11 +23,13 @@ from urllib import parse import attr from parameterized import parameterized, parameterized_class from PIL import Image as Image +from typing_extensions import Literal from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor +from synapse.api.errors import Codes from synapse.events import EventBase from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.logging.context import make_deferred_yieldable @@ -124,7 +126,9 @@ class _TestImage: expected_scaled: The expected bytes from scaled thumbnailing, or None if test should just check for a valid image returned. expected_found: True if the file should exist on the server, or False if - a 404 is expected. + a 404/400 is expected. + unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or + False if the thumbnailing should succeed or a normal 404 is expected. """ data: bytes @@ -133,6 +137,7 @@ class _TestImage: expected_cropped: Optional[bytes] = None expected_scaled: Optional[bytes] = None expected_found: bool = True + unable_to_thumbnail: bool = False @parameterized_class( @@ -190,6 +195,7 @@ class _TestImage: b"image/gif", b".gif", expected_found=False, + unable_to_thumbnail=True, ), ), ], @@ -364,18 +370,29 @@ class MediaRepoTests(unittest.HomeserverTestCase): def test_thumbnail_crop(self) -> None: """Test that a cropped remote thumbnail is available.""" self._test_thumbnail( - "crop", self.test_image.expected_cropped, self.test_image.expected_found + "crop", + self.test_image.expected_cropped, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) def test_thumbnail_scale(self) -> None: """Test that a scaled remote thumbnail is available.""" self._test_thumbnail( - "scale", self.test_image.expected_scaled, self.test_image.expected_found + "scale", + self.test_image.expected_scaled, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) def test_invalid_type(self) -> None: """An invalid thumbnail type is never available.""" - self._test_thumbnail("invalid", None, False) + self._test_thumbnail( + "invalid", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]} @@ -384,7 +401,12 @@ class MediaRepoTests(unittest.HomeserverTestCase): """ Override the config to generate only scaled thumbnails, but request a cropped one. """ - self._test_thumbnail("crop", None, False) + self._test_thumbnail( + "crop", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) @unittest.override_config( {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]} @@ -393,14 +415,22 @@ class MediaRepoTests(unittest.HomeserverTestCase): """ Override the config to generate only cropped thumbnails, but request a scaled one. """ - self._test_thumbnail("scale", None, False) + self._test_thumbnail( + "scale", + None, + expected_found=False, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, + ) def test_thumbnail_repeated_thumbnail(self) -> None: """Test that fetching the same thumbnail works, and deleting the on disk thumbnail regenerates it. """ self._test_thumbnail( - "scale", self.test_image.expected_scaled, self.test_image.expected_found + "scale", + self.test_image.expected_scaled, + expected_found=self.test_image.expected_found, + unable_to_thumbnail=self.test_image.unable_to_thumbnail, ) if not self.test_image.expected_found: @@ -457,8 +487,24 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) def _test_thumbnail( - self, method: str, expected_body: Optional[bytes], expected_found: bool + self, + method: str, + expected_body: Optional[bytes], + expected_found: bool, + unable_to_thumbnail: bool = False, ) -> None: + """Test the given thumbnailing method works as expected. + + Args: + method: The thumbnailing method to use (crop, scale). + expected_body: The expected bytes from thumbnailing, or None if + test should just check for a valid image. + expected_found: True if the file should exist on the server, or False if + a 404/400 is expected. + unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or + False if the thumbnailing should succeed or a normal 404 is expected. + """ + params = "?width=32&height=32&method=" + method channel = make_request( self.reactor, @@ -481,6 +527,12 @@ class MediaRepoTests(unittest.HomeserverTestCase): if expected_found: self.assertEqual(channel.code, 200) + + self.assertEqual( + channel.headers.getRawHeaders(b"Cross-Origin-Resource-Policy"), + [b"cross-origin"], + ) + if expected_body is not None: self.assertEqual( channel.result["body"], expected_body, channel.result["body"] @@ -488,6 +540,16 @@ class MediaRepoTests(unittest.HomeserverTestCase): else: # ensure that the result is at least some valid image Image.open(BytesIO(channel.result["body"])) + elif unable_to_thumbnail: + # A 400 with a JSON body. + self.assertEqual(channel.code, 400) + self.assertEqual( + channel.json_body, + { + "errcode": "M_UNKNOWN", + "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", + }, + ) else: # A 404 with a JSON body. self.assertEqual(channel.code, 404) @@ -549,10 +611,26 @@ class MediaRepoTests(unittest.HomeserverTestCase): [b"noindex, nofollow, noarchive, noimageindex"], ) + def test_cross_origin_resource_policy_header(self) -> None: + """ + Test that the Cross-Origin-Resource-Policy header is set to "cross-origin" + allowing web clients to embed media from the downloads API. + """ + channel = self._req(b"inline; filename=out" + self.test_image.extension) -class TestSpamChecker: + headers = channel.headers + + self.assertEqual( + headers.getRawHeaders(b"Cross-Origin-Resource-Policy"), + [b"cross-origin"], + ) + + +class TestSpamCheckerLegacy: """A spam checker module that rejects all media that includes the bytes `evil`. + + Uses the legacy Spam-Checker API. """ def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None: @@ -593,7 +671,7 @@ class TestSpamChecker: return b"evil" in buf.getvalue() -class SpamCheckerTestCase(unittest.HomeserverTestCase): +class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase): servlets = [ login.register_servlets, admin.register_servlets, @@ -617,7 +695,8 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): { "spam_checker": [ { - "module": TestSpamChecker.__module__ + ".TestSpamChecker", + "module": TestSpamCheckerLegacy.__module__ + + ".TestSpamCheckerLegacy", "config": {}, } ] @@ -642,3 +721,62 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): self.helper.upload_media( self.upload_resource, data, tok=self.tok, expect_code=400 ) + + +EVIL_DATA = b"Some evil data" +EVIL_DATA_EXPERIMENT = b"Some evil data to trigger the experimental tuple API" + + +class SpamCheckerTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + admin.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user = self.register_user("user", "pass") + self.tok = self.login("user", "pass") + + # Allow for uploading and downloading to/from the media repo + self.media_repo = hs.get_media_repository_resource() + self.download_resource = self.media_repo.children[b"download"] + self.upload_resource = self.media_repo.children[b"upload"] + + hs.get_module_api().register_spam_checker_callbacks( + check_media_file_for_spam=self.check_media_file_for_spam + ) + + async def check_media_file_for_spam( + self, file_wrapper: ReadableFileWrapper, file_info: FileInfo + ) -> Union[Codes, Literal["NOT_SPAM"]]: + buf = BytesIO() + await file_wrapper.write_chunks_to(buf.write) + + if buf.getvalue() == EVIL_DATA: + return Codes.FORBIDDEN + elif buf.getvalue() == EVIL_DATA_EXPERIMENT: + return (Codes.FORBIDDEN, {}) + else: + return "NOT_SPAM" + + def test_upload_innocent(self) -> None: + """Attempt to upload some innocent data that should be allowed.""" + self.helper.upload_media( + self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 + ) + + def test_upload_ban(self) -> None: + """Attempt to upload some data that includes bytes "evil", which should + get rejected by the spam checker. + """ + + self.helper.upload_media( + self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400 + ) + + self.helper.upload_media( + self.upload_resource, + EVIL_DATA_EXPERIMENT, + tok=self.tok, + expect_code=400, + ) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 11f78f52b8..d8faafec75 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -59,6 +59,28 @@ class WellKnownTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + @unittest.override_config( + { + "public_baseurl": "https://tesths", + "default_identity_server": "https://testis", + "extra_well_known_client_content": {"custom": False}, + } + ) + def test_client_well_known_custom(self) -> None: + channel = self.make_request( + "GET", "/.well-known/matrix/client", shorthand=False + ) + + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual( + channel.json_body, + { + "m.homeserver": {"base_url": "https://tesths/"}, + "m.identity_server": {"base_url": "https://testis"}, + "custom": False, + }, + ) + @unittest.override_config({"serve_server_wellknown": True}) def test_server_well_known(self) -> None: channel = self.make_request( diff --git a/tests/server.py b/tests/server.py index b9f465971f..df3f1564c9 100644 --- a/tests/server.py +++ b/tests/server.py @@ -43,6 +43,7 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( IAddress, + IConsumer, IHostnameResolver, IProtocol, IPullProducer, @@ -53,11 +54,7 @@ from twisted.internet.interfaces import ( ITransport, ) from twisted.python.failure import Failure -from twisted.test.proto_helpers import ( - AccumulatingProtocol, - MemoryReactor, - MemoryReactorClock, -) +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site @@ -96,6 +93,7 @@ class TimedOutException(Exception): """ +@implementer(IConsumer) @attr.s(auto_attribs=True) class FakeChannel: """ @@ -104,7 +102,7 @@ class FakeChannel: """ site: Union[Site, "FakeSite"] - _reactor: MemoryReactor + _reactor: MemoryReactorClock result: dict = attr.Factory(dict) _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None @@ -122,7 +120,7 @@ class FakeChannel: self._request = request @property - def json_body(self): + def json_body(self) -> JsonDict: return json.loads(self.text_body) @property @@ -140,7 +138,7 @@ class FakeChannel: return self.result.get("done", False) @property - def code(self): + def code(self) -> int: if not self.result: raise Exception("No result yet.") return int(self.result["code"]) @@ -160,7 +158,7 @@ class FakeChannel: self.result["reason"] = reason self.result["headers"] = headers - def write(self, content): + def write(self, content: bytes) -> None: assert isinstance(content, bytes), "Should be bytes! " + repr(content) if "body" not in self.result: @@ -168,11 +166,16 @@ class FakeChannel: self.result["body"] += content - def registerProducer(self, producer, streaming): + # Type ignore: mypy doesn't like the fact that producer isn't an IProducer. + def registerProducer( # type: ignore[override] + self, + producer: Union[IPullProducer, IPushProducer], + streaming: bool, + ) -> None: self._producer = producer self.producerStreaming = streaming - def _produce(): + def _produce() -> None: if self._producer: self._producer.resumeProducing() self._reactor.callLater(0.1, _produce) @@ -180,31 +183,32 @@ class FakeChannel: if not streaming: self._reactor.callLater(0.0, _produce) - def unregisterProducer(self): + def unregisterProducer(self) -> None: if self._producer is None: return self._producer = None - def requestDone(self, _self): + def requestDone(self, _self: Request) -> None: self.result["done"] = True if isinstance(_self, SynapseRequest): + assert _self.logcontext is not None self.resource_usage = _self.logcontext.get_resource_usage() - def getPeer(self): + def getPeer(self) -> IAddress: # We give an address so that getClientAddress/getClientIP returns a non null entry, # causing us to record the MAU return address.IPv4Address("TCP", self._ip, 3423) - def getHost(self): + def getHost(self) -> IAddress: # this is called by Request.__init__ to configure Request.host. return address.IPv4Address("TCP", "127.0.0.1", 8888) - def isSecure(self): + def isSecure(self) -> bool: return False @property - def transport(self): + def transport(self) -> "FakeChannel": return self def await_result(self, timeout_ms: int = 1000) -> None: @@ -830,7 +834,6 @@ def setup_test_homeserver( # Mock TLS hs.tls_server_context_factory = Mock() - hs.tls_client_options_factory = Mock() hs.setup() if homeserver_to_use == TestHomeServer: diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 38963ce4a7..46d829b062 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -143,7 +143,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) def test_simple(self): """Test that we cache events that we pull from the DB.""" @@ -160,7 +160,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): """ # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: # We keep hold of the event event though we never use it. @@ -170,7 +170,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: self.get_success(self.store.get_event(self.event_id)) @@ -345,7 +345,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) @contextmanager def blocking_get_event_calls( diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 9abd0cb446..1edb619630 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json + +from synapse.api.constants import RoomTypes from synapse.rest import admin from synapse.rest.client import login, room from synapse.storage.databases.main.room import _BackgroundUpdates @@ -91,3 +94,69 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ) ) self.assertEqual(room_creator_after, self.user_id) + + def test_background_add_room_type_column(self): + """Test that the background update to populate the `room_type` column in + `room_stats_state` works properly. + """ + + # Create a room without a type + room_id = self._generate_room() + + # Get event_id of the m.room.create event + event_id = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="current_state_events", + keyvalues={ + "room_id": room_id, + "type": "m.room.create", + }, + retcol="event_id", + ) + ) + + # Fake a room creation event with a room type + event = { + "content": { + "creator": "@user:server.org", + "room_version": "9", + "type": RoomTypes.SPACE, + }, + "type": "m.room.create", + } + self.get_success( + self.store.db_pool.simple_update( + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": json.dumps(event)}, + desc="test", + ) + ) + + # Insert and run the background update + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN, + "progress_json": "{}", + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + self.wait_for_background_updates() + + # Make sure the background update filled in the room type + room_type_after = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="room_stats_state", + keyvalues={"room_id": room_id}, + retcol="room_type", + allow_none=True, + ) + ) + self.assertEqual(room_type_after, RoomTypes.SPACE) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 4273524c4c..ba40124c8a 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -12,161 +12,162 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.util import Clock from tests.unittest import HomeserverTestCase USER_ID = "@user:example.com" -PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}] -HIGHLIGHT = [ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, -] - class EventPushActionsStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - self.persist_events_store = hs.get_datastores().persist_events + persist_events_store = hs.get_datastores().persist_events + assert persist_events_store is not None + self.persist_events_store = persist_events_store - def test_get_unread_push_actions_for_user_in_range_for_http(self): + def test_get_unread_push_actions_for_user_in_range_for_http(self) -> None: self.get_success( self.store.get_unread_push_actions_for_user_in_range_for_http( USER_ID, 0, 1000, 20 ) ) - def test_get_unread_push_actions_for_user_in_range_for_email(self): + def test_get_unread_push_actions_for_user_in_range_for_email(self) -> None: self.get_success( self.store.get_unread_push_actions_for_user_in_range_for_email( USER_ID, 0, 1000, 20 ) ) - def test_count_aggregation(self): - room_id = "!foo:example.com" - user_id = "@user1235:example.com" + def test_count_aggregation(self) -> None: + # Create a user to receive notifications and send receipts. + user_id = self.register_user("user1235", "pass") + token = self.login("user1235", "pass") + + # And another users to send events. + other_id = self.register_user("other", "pass") + other_token = self.login("other", "pass") - last_read_stream_ordering = [0] + # Create a room and put both users in it. + room_id = self.helper.create_room_as(user_id, tok=token) + self.helper.join(room_id, other_id, tok=other_token) - def _assert_counts(noitf_count, highlight_count): + last_event_id: str + + def _assert_counts( + noitf_count: int, unread_count: int, highlight_count: int + ) -> None: counts = self.get_success( self.store.db_pool.runInteraction( - "", - self.store._get_unread_counts_by_pos_txn, + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, room_id, user_id, - last_read_stream_ordering[0], ) ) self.assertEqual( counts, NotifCounts( notify_count=noitf_count, - unread_count=0, # Unread counts are tested in the sync tests. + unread_count=unread_count, highlight_count=highlight_count, ), ) - def _inject_actions(stream, action): - event = Mock() - event.room_id = room_id - event.event_id = "$test:example.com" - event.internal_metadata.stream_ordering = stream - event.internal_metadata.is_outlier.return_value = False - event.depth = stream - - self.get_success( - self.store.add_push_actions_to_staging( - event.event_id, - {user_id: action}, - False, - ) - ) - self.get_success( - self.store.db_pool.runInteraction( - "", - self.persist_events_store._set_push_actions_for_event_and_users_txn, - [(event, None)], - [(event, None)], - ) + def _create_event(highlight: bool = False) -> str: + result = self.helper.send_event( + room_id, + type="m.room.message", + content={"msgtype": "m.text", "body": user_id if highlight else "msg"}, + tok=other_token, ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id - def _rotate(stream): - self.get_success( - self.store.db_pool.runInteraction( - "", self.store._rotate_notifs_before_txn, stream - ) - ) + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) - def _mark_read(stream, depth): - last_read_stream_ordering[0] = stream + def _mark_read(event_id: str) -> None: self.get_success( - self.store.db_pool.runInteraction( - "", - self.store._remove_old_push_actions_before_txn, + self.store.insert_receipt( room_id, - user_id, - stream, + "m.read", + user_id=user_id, + event_ids=[event_id], + data={}, ) ) - _assert_counts(0, 0) - _inject_actions(1, PlAIN_NOTIF) - _assert_counts(1, 0) - _rotate(2) - _assert_counts(1, 0) + _assert_counts(0, 0, 0) + _create_event() + _assert_counts(1, 1, 0) + _rotate() + _assert_counts(1, 1, 0) - _inject_actions(3, PlAIN_NOTIF) - _assert_counts(2, 0) - _rotate(4) - _assert_counts(2, 0) + event_id = _create_event() + _assert_counts(2, 2, 0) + _rotate() + _assert_counts(2, 2, 0) - _inject_actions(5, PlAIN_NOTIF) - _mark_read(3, 3) - _assert_counts(1, 0) + _create_event() + _mark_read(event_id) + _assert_counts(1, 1, 0) - _mark_read(5, 5) - _assert_counts(0, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) - _inject_actions(6, PlAIN_NOTIF) - _rotate(7) + _create_event() + _rotate() + _assert_counts(1, 1, 0) - self.get_success( - self.store.db_pool.simple_delete( - table="event_push_actions", keyvalues={"1": 1}, desc="" - ) - ) + # Delete old event push actions, this should not affect the (summarised) count. + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + _assert_counts(1, 1, 0) - _assert_counts(1, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) - _mark_read(7, 7) - _assert_counts(0, 0) - - _inject_actions(8, HIGHLIGHT) - _assert_counts(1, 1) - _rotate(9) - _assert_counts(1, 1) + event_id = _create_event(True) + _assert_counts(1, 1, 1) + _rotate() + _assert_counts(1, 1, 1) # Check that adding another notification and rotating after highlight # works. - _inject_actions(10, PlAIN_NOTIF) - _rotate(11) - _assert_counts(2, 1) + _create_event() + _rotate() + _assert_counts(2, 2, 1) # Check that sending read receipts at different points results in the # right counts. - _mark_read(8, 8) - _assert_counts(1, 0) - _mark_read(10, 10) - _assert_counts(0, 0) - - def test_find_first_stream_ordering_after_ts(self): - def add_event(so, ts): + _mark_read(event_id) + _assert_counts(1, 1, 0) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) + + _create_event(True) + _assert_counts(1, 1, 1) + _mark_read(last_event_id) + _assert_counts(0, 0, 0) + _rotate() + _assert_counts(0, 0, 0) + + def test_find_first_stream_ordering_after_ts(self) -> None: + def add_event(so: int, ts: int) -> None: self.get_success( self.store.db_pool.simple_insert( "events", diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 8dfaa0559b..9c1182ed16 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -115,6 +115,6 @@ class PurgeTests(HomeserverTestCase): ) # The events aren't found. - self.store._invalidate_get_event_cache(create_event.event_id) + self.store._invalidate_local_get_event_cache(create_event.event_id) self.get_failure(self.store.get_event(create_event.event_id), NotFoundError) self.get_failure(self.store.get_event(first["event_id"]), NotFoundError) diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/storage/test_receipts.py index 19f57115a1..b1a8f8bba7 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -13,23 +13,21 @@ # limitations under the License. from synapse.api.constants import ReceiptTypes -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.types import UserID, create_requester from tests.test_utils.event_injection import create_event - -from ._base import BaseSlavedStoreTestCase +from tests.unittest import HomeserverTestCase OTHER_USER_ID = "@other:test" OUR_USER_ID = "@our:test" -class SlavedReceiptTestCase(BaseSlavedStoreTestCase): - - STORE_TYPE = SlavedReceiptsStore - +class ReceiptTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): super().prepare(reactor, clock, homeserver) + + self.store = homeserver.get_datastores().main + self.room_creator = homeserver.get_room_creation_handler() self.persist_event_storage_controller = ( self.hs.get_storage_controllers().persistence @@ -87,14 +85,14 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): def test_return_empty_with_no_data(self): res = self.get_success( - self.master_store.get_receipts_for_user( + self.store.get_receipts_for_user( OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) ) self.assertEqual(res, {}) res = self.get_success( - self.master_store.get_receipts_for_user_with_orderings( + self.store.get_receipts_for_user_with_orderings( OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], ) @@ -102,7 +100,7 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): self.assertEqual(res, {}) res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( + self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], @@ -121,20 +119,20 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Send public read receipt for the first event self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} ) ) # Send private read receipt for the second event self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} ) ) # Test we get the latest event when we want both private and public receipts res = self.get_success( - self.master_store.get_receipts_for_user( + self.store.get_receipts_for_user( OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) ) @@ -142,26 +140,24 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Test we get the older event when we want only public receipt res = self.get_success( - self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) ) self.assertEqual(res, {self.room_id1: event1_1_id}) # Test we get the latest event when we want only the public receipt res = self.get_success( - self.master_store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ_PRIVATE] - ) + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]) ) self.assertEqual(res, {self.room_id1: event1_2_id}) # Test receipt updating self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} ) ) res = self.get_success( - self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -172,12 +168,12 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Test new room is reflected in what the method returns self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( - self.master_store.get_receipts_for_user( + self.store.get_receipts_for_user( OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] ) ) @@ -194,20 +190,20 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Send public read receipt for the first event self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} ) ) # Send private read receipt for the second event self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} ) ) # Test we get the latest event when we want both private and public receipts res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( + self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], @@ -217,7 +213,7 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Test we get the older event when we want only public receipt res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( + self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] ) ) @@ -225,7 +221,7 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Test we get the latest event when we want only the private receipt res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( + self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] ) ) @@ -233,12 +229,12 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Test receipt updating self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} ) ) res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( + self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] ) ) @@ -251,12 +247,12 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Test new room is reflected in what the method returns self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( + self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id2, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 3c79dabc9f..3405efb6a8 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions from synapse.types import RoomAlias, RoomID, UserID @@ -65,71 +64,3 @@ class RoomStoreTestCase(HomeserverTestCase): self.assertIsNone( (self.get_success(self.store.get_room_with_stats("!uknown:test"))), ) - - -class RoomEventsStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): - # Room events need the full datastore, for persist_event() and - # get_room_state() - self.store = hs.get_datastores().main - self._storage_controllers = hs.get_storage_controllers() - self.event_factory = hs.get_event_factory() - - self.room = RoomID.from_string("!abcde:test") - - self.get_success( - self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, - ) - ) - - def inject_room_event(self, **kwargs): - self.get_success( - self._storage_controllers.persistence.persist_event( - self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) - ) - ) - - def STALE_test_room_name(self): - name = "A-Room-Name" - - self.inject_room_event( - etype=EventTypes.Name, name=name, content={"name": name}, depth=1 - ) - - state = self.get_success( - self._storage_controllers.state.get_current_state( - room_id=self.room.to_string() - ) - ) - - self.assertEqual(1, len(state)) - self.assertObjectHasAttributes( - {"type": "m.room.name", "room_id": self.room.to_string(), "name": name}, - state[0], - ) - - def STALE_test_room_topic(self): - topic = "A place for things" - - self.inject_room_event( - etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 - ) - - state = self.get_success( - self._storage_controllers.state.get_current_state( - room_id=self.room.to_string() - ) - ) - - self.assertEqual(1, len(state)) - self.assertObjectHasAttributes( - {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic}, - state[0], - ) - - # Not testing the various 'level' methods for now because there's lots - # of them and need coalescing; see JIRA SPEC-11 diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 1218786d79..240b02cb9f 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -23,7 +23,6 @@ from synapse.util import Clock from tests import unittest from tests.server import TestHomeServer -from tests.test_utils import event_injection class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @@ -110,60 +109,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) - def test_get_joined_users_from_context(self) -> None: - room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - bob_event = self.get_success( - event_injection.inject_member_event( - self.hs, room, self.u_bob, Membership.JOIN - ) - ) - - # first, create a regular event - event, context = self.get_success( - event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[bob_event.event_id], - type="m.test.1", - content={}, - ) - ) - - users = self.get_success( - self.store.get_joined_users_from_context(event, context) - ) - self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) - - # Regression test for #7376: create a state event whose key matches bob's - # user_id, but which is *not* a membership event, and persist that; then check - # that `get_joined_users_from_context` returns the correct users for the next event. - non_member_event = self.get_success( - event_injection.inject_event( - self.hs, - room_id=room, - sender=self.u_bob, - prev_event_ids=[bob_event.event_id], - type="m.test.2", - state_key=self.u_bob, - content={}, - ) - ) - event, context = self.get_success( - event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[non_member_event.event_id], - type="m.test.3", - content={}, - ) - ) - users = self.get_success( - self.store.get_joined_users_from_context(event, context) - ) - self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) - def test__null_byte_in_display_name_properly_handled(self) -> None: room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 8043bdbde2..5564161750 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -369,8 +369,8 @@ class StateStoreTestCase(HomeserverTestCase): state_dict_ids = cache_entry.value self.assertEqual(cache_entry.full, False) - self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)}) - self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) + self.assertEqual(cache_entry.known_absent, set()) + self.assertDictEqual(state_dict_ids, {}) ############################################ # test that things work with a partial cache @@ -387,7 +387,7 @@ class StateStoreTestCase(HomeserverTestCase): ) self.assertEqual(is_all, False) - self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) + self.assertDictEqual({}, state_dict) room_id = self.room.to_string() (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( @@ -412,7 +412,7 @@ class StateStoreTestCase(HomeserverTestCase): ) self.assertEqual(is_all, False) - self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) + self.assertDictEqual({}, state_dict) (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, @@ -443,7 +443,7 @@ class StateStoreTestCase(HomeserverTestCase): ) self.assertEqual(is_all, False) - self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) + self.assertDictEqual({}, state_dict) (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 229ecd84a6..e42d7b9ba0 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -1,4 +1,4 @@ -# Copyright 2018 New Vector Ltd +# Copyright 2018-2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,17 +13,50 @@ # limitations under the License. import unittest -from typing import Optional +from typing import Collection, Dict, Iterable, List, Optional from parameterized import parameterized from synapse import event_auth from synapse.api.constants import EventContentFields -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.events import EventBase, make_event_from_dict +from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import JsonDict, get_domain_from_id +from tests.test_utils import get_awaitable_result + + +class _StubEventSourceStore: + """A stub implementation of the EventSourceStore""" + + def __init__(self): + self._store: Dict[str, EventBase] = {} + + def add_event(self, event: EventBase): + self._store[event.event_id] = event + + def add_events(self, events: Iterable[EventBase]): + for event in events: + self._store[event.event_id] = event + + async def get_events( + self, + event_ids: Collection[str], + redact_behaviour: EventRedactBehaviour, + get_prev_content: bool = False, + allow_rejected: bool = False, + ) -> Dict[str, EventBase]: + assert allow_rejected + assert not get_prev_content + assert redact_behaviour == EventRedactBehaviour.as_is + results = {} + for e in event_ids: + if e in self._store: + results[e] = self._store[e] + return results + class EventAuthTestCase(unittest.TestCase): def test_rejected_auth_events(self): @@ -36,11 +69,15 @@ class EventAuthTestCase(unittest.TestCase): _join_event(RoomVersions.V9, creator), ] + event_store = _StubEventSourceStore() + event_store.add_events(auth_events) + # creator should be able to send state - event_auth.check_auth_rules_for_event( - _random_state_event(RoomVersions.V9, creator), - auth_events, + event = _random_state_event(RoomVersions.V9, creator, auth_events) + get_awaitable_result( + event_auth.check_state_independent_auth_rules(event_store, event) ) + event_auth.check_state_dependent_auth_rules(event, auth_events) # ... but a rejected join_rules event should cause it to be rejected rejected_join_rules = _join_rules_event( @@ -50,23 +87,154 @@ class EventAuthTestCase(unittest.TestCase): ) rejected_join_rules.rejected_reason = "stinky" auth_events.append(rejected_join_rules) + event_store.add_event(rejected_join_rules) - self.assertRaises( - AuthError, - event_auth.check_auth_rules_for_event, - _random_state_event(RoomVersions.V9, creator), - auth_events, - ) + with self.assertRaises(AuthError): + get_awaitable_result( + event_auth.check_state_independent_auth_rules( + event_store, + _random_state_event(RoomVersions.V9, creator), + ) + ) # ... even if there is *also* a good join rules auth_events.append(_join_rules_event(RoomVersions.V9, creator, "public")) + event_store.add_event(rejected_join_rules) - self.assertRaises( - AuthError, - event_auth.check_auth_rules_for_event, - _random_state_event(RoomVersions.V9, creator), - auth_events, + with self.assertRaises(AuthError): + get_awaitable_result( + event_auth.check_state_independent_auth_rules( + event_store, + _random_state_event(RoomVersions.V9, creator), + ) + ) + + def test_create_event_with_prev_events(self): + """A create event with prev_events should be rejected + + https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules + 1: If type is m.room.create: + 1. If it has any previous events, reject. + """ + creator = f"@creator:{TEST_DOMAIN}" + + # we make both a good event and a bad event, to check that we are rejecting + # the bad event for the reason we think we are. + good_event = make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + "type": "m.room.create", + "state_key": "", + "sender": creator, + "content": { + "creator": creator, + "room_version": RoomVersions.V9.identifier, + }, + "auth_events": [], + "prev_events": [], + }, + room_version=RoomVersions.V9, ) + bad_event = make_event_from_dict( + {**good_event.get_dict(), "prev_events": ["$fakeevent"]}, + room_version=RoomVersions.V9, + ) + + event_store = _StubEventSourceStore() + + get_awaitable_result( + event_auth.check_state_independent_auth_rules(event_store, good_event) + ) + with self.assertRaises(AuthError): + get_awaitable_result( + event_auth.check_state_independent_auth_rules(event_store, bad_event) + ) + + def test_duplicate_auth_events(self): + """Events with duplicate auth_events should be rejected + + https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules + 2. Reject if event has auth_events that: + 1. have duplicate entries for a given type and state_key pair + """ + creator = "@creator:example.com" + + create_event = _create_event(RoomVersions.V9, creator) + join_event1 = _join_event(RoomVersions.V9, creator) + pl_event = _power_levels_event( + RoomVersions.V9, + creator, + {"state_default": 30, "users": {"creator": 100}}, + ) + + # create a second join event, so that we can make a duplicate + join_event2 = _join_event(RoomVersions.V9, creator) + + event_store = _StubEventSourceStore() + event_store.add_events([create_event, join_event1, join_event2, pl_event]) + + good_event = _random_state_event( + RoomVersions.V9, creator, [create_event, join_event2, pl_event] + ) + bad_event = _random_state_event( + RoomVersions.V9, creator, [create_event, join_event1, join_event2, pl_event] + ) + # a variation: two instances of the *same* event + bad_event2 = _random_state_event( + RoomVersions.V9, creator, [create_event, join_event2, join_event2, pl_event] + ) + + get_awaitable_result( + event_auth.check_state_independent_auth_rules(event_store, good_event) + ) + with self.assertRaises(AuthError): + get_awaitable_result( + event_auth.check_state_independent_auth_rules(event_store, bad_event) + ) + with self.assertRaises(AuthError): + get_awaitable_result( + event_auth.check_state_independent_auth_rules(event_store, bad_event2) + ) + + def test_unexpected_auth_events(self): + """Events with excess auth_events should be rejected + + https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules + 2. Reject if event has auth_events that: + 2. have entries whose type and state_key don’t match those specified by the + auth events selection algorithm described in the server specification. + """ + creator = "@creator:example.com" + + create_event = _create_event(RoomVersions.V9, creator) + join_event = _join_event(RoomVersions.V9, creator) + pl_event = _power_levels_event( + RoomVersions.V9, + creator, + {"state_default": 30, "users": {"creator": 100}}, + ) + join_rules_event = _join_rules_event(RoomVersions.V9, creator, "public") + + event_store = _StubEventSourceStore() + event_store.add_events([create_event, join_event, pl_event, join_rules_event]) + + good_event = _random_state_event( + RoomVersions.V9, creator, [create_event, join_event, pl_event] + ) + # join rules should *not* be included in the auth events. + bad_event = _random_state_event( + RoomVersions.V9, + creator, + [create_event, join_event, pl_event, join_rules_event], + ) + + get_awaitable_result( + event_auth.check_state_independent_auth_rules(event_store, good_event) + ) + with self.assertRaises(AuthError): + get_awaitable_result( + event_auth.check_state_independent_auth_rules(event_store, bad_event) + ) def test_random_users_cannot_send_state_before_first_pl(self): """ @@ -82,7 +250,7 @@ class EventAuthTestCase(unittest.TestCase): ] # creator should be able to send state - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _random_state_event(RoomVersions.V1, creator), auth_events, ) @@ -90,7 +258,7 @@ class EventAuthTestCase(unittest.TestCase): # joiner should not be able to send state self.assertRaises( AuthError, - event_auth.check_auth_rules_for_event, + event_auth.check_state_dependent_auth_rules, _random_state_event(RoomVersions.V1, joiner), auth_events, ) @@ -119,13 +287,13 @@ class EventAuthTestCase(unittest.TestCase): # pleb should not be able to send state self.assertRaises( AuthError, - event_auth.check_auth_rules_for_event, + event_auth.check_state_dependent_auth_rules, _random_state_event(RoomVersions.V1, pleb), auth_events, ), # king should be able to send state - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _random_state_event(RoomVersions.V1, king), auth_events, ) @@ -140,27 +308,27 @@ class EventAuthTestCase(unittest.TestCase): ] # creator should be able to send aliases - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _alias_event(RoomVersions.V1, creator), auth_events, ) # Reject an event with no state key. with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _alias_event(RoomVersions.V1, creator, state_key=""), auth_events, ) # If the domain of the sender does not match the state key, reject. with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _alias_event(RoomVersions.V1, creator, state_key="test.com"), auth_events, ) # Note that the member does *not* need to be in the room. - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _alias_event(RoomVersions.V1, other), auth_events, ) @@ -175,24 +343,24 @@ class EventAuthTestCase(unittest.TestCase): ] # creator should be able to send aliases - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _alias_event(RoomVersions.V6, creator), auth_events, ) # No particular checks are done on the state key. - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _alias_event(RoomVersions.V6, creator, state_key=""), auth_events, ) - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _alias_event(RoomVersions.V6, creator, state_key="test.com"), auth_events, ) # Per standard auth rules, the member must be in the room. with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _alias_event(RoomVersions.V6, other), auth_events, ) @@ -220,12 +388,12 @@ class EventAuthTestCase(unittest.TestCase): # on room V1, pleb should be able to modify the notifications power level. if allow_modification: - event_auth.check_auth_rules_for_event(pl_event, auth_events) + event_auth.check_state_dependent_auth_rules(pl_event, auth_events) else: # But an MSC2209 room rejects this change. with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event(pl_event, auth_events) + event_auth.check_state_dependent_auth_rules(pl_event, auth_events) def test_join_rules_public(self): """ @@ -243,14 +411,14 @@ class EventAuthTestCase(unittest.TestCase): } # Check join. - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V6, pleb), auth_events.values(), ) # A user cannot be force-joined to a room. with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _member_event(RoomVersions.V6, pleb, "join", sender=creator), auth_events.values(), ) @@ -260,7 +428,7 @@ class EventAuthTestCase(unittest.TestCase): RoomVersions.V6, pleb, "ban" ) with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V6, pleb), auth_events.values(), ) @@ -269,7 +437,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event( RoomVersions.V6, pleb, "leave" ) - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V6, pleb), auth_events.values(), ) @@ -278,7 +446,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event( RoomVersions.V6, pleb, "join" ) - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V6, pleb), auth_events.values(), ) @@ -287,7 +455,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event( RoomVersions.V6, pleb, "invite", sender=creator ) - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V6, pleb), auth_events.values(), ) @@ -309,14 +477,14 @@ class EventAuthTestCase(unittest.TestCase): # A join without an invite is rejected. with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V6, pleb), auth_events.values(), ) # A user cannot be force-joined to a room. with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _member_event(RoomVersions.V6, pleb, "join", sender=creator), auth_events.values(), ) @@ -326,7 +494,7 @@ class EventAuthTestCase(unittest.TestCase): RoomVersions.V6, pleb, "ban" ) with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V6, pleb), auth_events.values(), ) @@ -336,7 +504,7 @@ class EventAuthTestCase(unittest.TestCase): RoomVersions.V6, pleb, "leave" ) with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V6, pleb), auth_events.values(), ) @@ -345,7 +513,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event( RoomVersions.V6, pleb, "join" ) - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V6, pleb), auth_events.values(), ) @@ -354,7 +522,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event( RoomVersions.V6, pleb, "invite", sender=creator ) - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V6, pleb), auth_events.values(), ) @@ -376,7 +544,7 @@ class EventAuthTestCase(unittest.TestCase): } with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V6, pleb), auth_events.values(), ) @@ -413,7 +581,7 @@ class EventAuthTestCase(unittest.TestCase): EventContentFields.AUTHORISING_USER: "@creator:example.com" }, ) - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( authorised_join_event, auth_events.values(), ) @@ -429,7 +597,7 @@ class EventAuthTestCase(unittest.TestCase): pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event( RoomVersions.V8, "@inviter:foo.test" ) - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event( RoomVersions.V8, pleb, @@ -442,7 +610,7 @@ class EventAuthTestCase(unittest.TestCase): # A join which is missing an authorised server is rejected. with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V8, pleb), auth_events.values(), ) @@ -455,7 +623,7 @@ class EventAuthTestCase(unittest.TestCase): {"invite": 100, "users": {"@other:example.com": 150}}, ) with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event( RoomVersions.V8, pleb, @@ -469,7 +637,7 @@ class EventAuthTestCase(unittest.TestCase): # A user cannot be force-joined to a room. (This uses an event which # *would* be valid, but is sent be a different user.) with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _member_event( RoomVersions.V8, pleb, @@ -487,7 +655,7 @@ class EventAuthTestCase(unittest.TestCase): RoomVersions.V8, pleb, "ban" ) with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( authorised_join_event, auth_events.values(), ) @@ -496,7 +664,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event( RoomVersions.V8, pleb, "leave" ) - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( authorised_join_event, auth_events.values(), ) @@ -506,7 +674,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event( RoomVersions.V8, pleb, "join" ) - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V8, pleb), auth_events.values(), ) @@ -516,15 +684,54 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event( RoomVersions.V8, pleb, "invite", sender=creator ) - event_auth.check_auth_rules_for_event( + event_auth.check_state_dependent_auth_rules( _join_event(RoomVersions.V8, pleb), auth_events.values(), ) + def test_room_v10_rejects_string_power_levels(self) -> None: + pl_event_content = {"users_default": "42"} + pl_event = make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10), + "type": "m.room.power_levels", + "sender": "@test:test.com", + "state_key": "", + "content": pl_event_content, + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + }, + room_version=RoomVersions.V10, + ) + + pl_event2_content = {"events": {"m.room.name": "42", "m.room.power_levels": 42}} + pl_event2 = make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + **_maybe_get_event_id_dict_for_room_version(RoomVersions.V10), + "type": "m.room.power_levels", + "sender": "@test:test.com", + "state_key": "", + "content": pl_event2_content, + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + }, + room_version=RoomVersions.V10, + ) -# helpers for making events + with self.assertRaises(SynapseError): + event_auth._check_power_levels( + pl_event.room_version, pl_event, {("fake_type", "fake_key"): pl_event2} + ) + + with self.assertRaises(SynapseError): + event_auth._check_power_levels( + pl_event.room_version, pl_event2, {("fake_type", "fake_key"): pl_event} + ) -TEST_ROOM_ID = "!test:room" + +# helpers for making events +TEST_DOMAIN = "example.com" +TEST_ROOM_ID = f"!test_room:{TEST_DOMAIN}" def _create_event( @@ -539,6 +746,7 @@ def _create_event( "state_key": "", "sender": user_id, "content": {"creator": user_id}, + "auth_events": [], }, room_version=room_version, ) @@ -559,6 +767,7 @@ def _member_event( "sender": sender or user_id, "state_key": user_id, "content": {"membership": membership, **(additional_content or {})}, + "auth_events": [], "prev_events": [], }, room_version=room_version, @@ -609,7 +818,22 @@ def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase: return make_event_from_dict(data, room_version=room_version) -def _random_state_event(room_version: RoomVersion, sender: str) -> EventBase: +def _build_auth_dict_for_room_version( + room_version: RoomVersion, auth_events: Iterable[EventBase] +) -> List: + if room_version.event_format == EventFormatVersions.V1: + return [(e.event_id, "not_used") for e in auth_events] + else: + return [e.event_id for e in auth_events] + + +def _random_state_event( + room_version: RoomVersion, + sender: str, + auth_events: Optional[Iterable[EventBase]] = None, +) -> EventBase: + if auth_events is None: + auth_events = [] return make_event_from_dict( { "room_id": TEST_ROOM_ID, @@ -618,6 +842,7 @@ def _random_state_event(room_version: RoomVersion, sender: str) -> EventBase: "sender": sender, "state_key": "", "content": {"membership": "join"}, + "auth_events": _build_auth_dict_for_room_version(room_version, auth_events), }, room_version=room_version, ) diff --git a/tests/test_federation.py b/tests/test_federation.py index 0cbef70bfa..779fad1f63 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -81,12 +81,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.handler = self.homeserver.get_federation_handler() federation_event_handler = self.homeserver.get_federation_event_handler() - async def _check_event_auth( - origin, - event, - context, - ): - return context + async def _check_event_auth(origin, event, context): + pass federation_event_handler._check_event_auth = _check_event_auth self.client = self.homeserver.get_federation_client() diff --git a/tests/test_server.py b/tests/test_server.py index 847432f791..2fe4411401 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -14,7 +14,7 @@ import re from http import HTTPStatus -from typing import Tuple +from typing import Awaitable, Callable, Dict, NoReturn, Optional, Tuple from twisted.internet.defer import Deferred from twisted.web.resource import Resource @@ -36,6 +36,7 @@ from synapse.util import Clock from tests import unittest from tests.http.server._base import test_disconnect from tests.server import ( + FakeChannel, FakeSite, ThreadedMemoryReactorClock, make_request, @@ -44,7 +45,7 @@ from tests.server import ( class JsonResourceTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() self.hs_clock = Clock(self.reactor) self.homeserver = setup_test_homeserver( @@ -54,7 +55,7 @@ class JsonResourceTests(unittest.TestCase): reactor=self.reactor, ) - def test_handler_for_request(self): + def test_handler_for_request(self) -> None: """ JsonResource.handler_for_request gives correctly decoded URL args to the callback, while Twisted will give the raw bytes of URL query @@ -62,7 +63,9 @@ class JsonResourceTests(unittest.TestCase): """ got_kwargs = {} - def _callback(request, **kwargs): + def _callback( + request: SynapseRequest, **kwargs: object + ) -> Tuple[int, Dict[str, object]]: got_kwargs.update(kwargs) return 200, kwargs @@ -83,13 +86,13 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) - def test_callback_direct_exception(self): + def test_callback_direct_exception(self) -> None: """ If the web callback raises an uncaught exception, it will be translated into a 500. """ - def _callback(request, **kwargs): + def _callback(request: SynapseRequest, **kwargs: object) -> NoReturn: raise Exception("boo") res = JsonResource(self.homeserver) @@ -103,17 +106,17 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.result["code"], b"500") - def test_callback_indirect_exception(self): + def test_callback_indirect_exception(self) -> None: """ If the web callback raises an uncaught exception in a Deferred, it will be translated into a 500. """ - def _throw(*args): + def _throw(*args: object) -> NoReturn: raise Exception("boo") - def _callback(request, **kwargs): - d = Deferred() + def _callback(request: SynapseRequest, **kwargs: object) -> "Deferred[None]": + d: "Deferred[None]" = Deferred() d.addCallback(_throw) self.reactor.callLater(0.5, d.callback, True) return make_deferred_yieldable(d) @@ -129,13 +132,13 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.result["code"], b"500") - def test_callback_synapseerror(self): + def test_callback_synapseerror(self) -> None: """ If the web callback raises a SynapseError, it returns the appropriate status code and message set in it. """ - def _callback(request, **kwargs): + def _callback(request: SynapseRequest, **kwargs: object) -> NoReturn: raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN) res = JsonResource(self.homeserver) @@ -151,12 +154,12 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.json_body["error"], "Forbidden!!one!") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - def test_no_handler(self): + def test_no_handler(self) -> None: """ If there is no handler to process the request, Synapse will return 400. """ - def _callback(request, **kwargs): + def _callback(request: SynapseRequest, **kwargs: object) -> None: """ Not ever actually called! """ @@ -175,14 +178,16 @@ class JsonResourceTests(unittest.TestCase): self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") - def test_head_request(self): + def test_head_request(self) -> None: """ JsonResource.handler_for_request gives correctly decoded URL args to the callback, while Twisted will give the raw bytes of URL query arguments. """ - def _callback(request, **kwargs): + def _callback( + request: SynapseRequest, **kwargs: object + ) -> Tuple[int, Dict[str, object]]: return 200, {"result": True} res = JsonResource(self.homeserver) @@ -203,20 +208,21 @@ class JsonResourceTests(unittest.TestCase): class OptionsResourceTests(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() class DummyResource(Resource): isLeaf = True - def render(self, request): - return request.path + def render(self, request: SynapseRequest) -> bytes: + # Type-ignore: mypy thinks request.path is Optional[Any], not bytes. + return request.path # type: ignore[return-value] # Setup a resource with some children. self.resource = OptionsResource() self.resource.putChild(b"res", DummyResource()) - def _make_request(self, method, path): + def _make_request(self, method: bytes, path: bytes) -> FakeChannel: """Create a request from the method/path and return a channel with the response.""" # Create a site and query for the resource. site = SynapseSite( @@ -225,7 +231,7 @@ class OptionsResourceTests(unittest.TestCase): parse_listener_def({"type": "http", "port": 0}), self.resource, "1.0", - max_request_body_size=1234, + max_request_body_size=4096, reactor=self.reactor, ) @@ -233,7 +239,7 @@ class OptionsResourceTests(unittest.TestCase): channel = make_request(self.reactor, site, method, path, shorthand=False) return channel - def test_unknown_options_request(self): + def test_unknown_options_request(self) -> None: """An OPTIONS requests to an unknown URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/foo/") self.assertEqual(channel.result["code"], b"204") @@ -253,7 +259,7 @@ class OptionsResourceTests(unittest.TestCase): "has CORS Headers header", ) - def test_known_options_request(self): + def test_known_options_request(self) -> None: """An OPTIONS requests to an known URL still returns 204 No Content.""" channel = self._make_request(b"OPTIONS", b"/res/") self.assertEqual(channel.result["code"], b"204") @@ -273,12 +279,12 @@ class OptionsResourceTests(unittest.TestCase): "has CORS Headers header", ) - def test_unknown_request(self): + def test_unknown_request(self) -> None: """A non-OPTIONS request to an unknown URL should 404.""" channel = self._make_request(b"GET", b"/foo/") self.assertEqual(channel.result["code"], b"404") - def test_known_request(self): + def test_known_request(self) -> None: """A non-OPTIONS request to an known URL should query the proper resource.""" channel = self._make_request(b"GET", b"/res/") self.assertEqual(channel.result["code"], b"200") @@ -287,16 +293,17 @@ class OptionsResourceTests(unittest.TestCase): class WrapHtmlRequestHandlerTests(unittest.TestCase): class TestResource(DirectServeHtmlResource): - callback = None + callback: Optional[Callable[..., Awaitable[None]]] - async def _async_render_GET(self, request): + async def _async_render_GET(self, request: SynapseRequest) -> None: + assert self.callback is not None await self.callback(request) - def setUp(self): + def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() - def test_good_response(self): - async def callback(request): + def test_good_response(self) -> None: + async def callback(request: SynapseRequest) -> None: request.write(b"response") request.finish() @@ -311,13 +318,13 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): body = channel.result["body"] self.assertEqual(body, b"response") - def test_redirect_exception(self): + def test_redirect_exception(self) -> None: """ If the callback raises a RedirectException, it is turned into a 30x with the right location. """ - async def callback(request, **kwargs): + async def callback(request: SynapseRequest, **kwargs: object) -> None: raise RedirectException(b"/look/an/eagle", 301) res = WrapHtmlRequestHandlerTests.TestResource() @@ -332,13 +339,13 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): location_headers = [v for k, v in headers if k == b"Location"] self.assertEqual(location_headers, [b"/look/an/eagle"]) - def test_redirect_exception_with_cookie(self): + def test_redirect_exception_with_cookie(self) -> None: """ If the callback raises a RedirectException which sets a cookie, that is returned too """ - async def callback(request, **kwargs): + async def callback(request: SynapseRequest, **kwargs: object) -> NoReturn: e = RedirectException(b"/no/over/there", 304) e.cookies.append(b"session=yespls") raise e @@ -357,10 +364,10 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): cookies_headers = [v for k, v in headers if k == b"Set-Cookie"] self.assertEqual(cookies_headers, [b"session=yespls"]) - def test_head_request(self): + def test_head_request(self) -> None: """A head request should work by being turned into a GET request.""" - async def callback(request): + async def callback(request: SynapseRequest) -> None: request.write(b"response") request.finish() @@ -410,7 +417,7 @@ class CancellableDirectServeHtmlResource(DirectServeHtmlResource): class DirectServeJsonResourceCancellationTests(unittest.TestCase): """Tests for `DirectServeJsonResource` cancellation.""" - def setUp(self): + def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() self.clock = Clock(self.reactor) self.resource = CancellableDirectServeJsonResource(self.clock) @@ -444,7 +451,7 @@ class DirectServeJsonResourceCancellationTests(unittest.TestCase): class DirectServeHtmlResourceCancellationTests(unittest.TestCase): """Tests for `DirectServeHtmlResource` cancellation.""" - def setUp(self): + def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() self.clock = Clock(self.reactor) self.resource = CancellableDirectServeHtmlResource(self.clock) diff --git a/tests/test_state.py b/tests/test_state.py index b005dd8d0f..bafd6d1750 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext -from synapse.state import StateHandler, StateResolutionHandler +from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry from synapse.util import Clock from synapse.util.macaroons import MacaroonGenerator @@ -99,6 +99,10 @@ class _DummyStore: state_group = self._next_group self._next_group += 1 + if current_state_ids is None: + current_state_ids = dict(self._group_to_state[prev_group]) + current_state_ids.update(delta_ids) + self._group_to_state[state_group] = dict(current_state_ids) return state_group @@ -131,7 +135,9 @@ class _DummyStore: async def get_room_version_id(self, room_id): return RoomVersions.V1.identifier - async def get_state_group_for_events(self, event_ids): + async def get_state_group_for_events( + self, event_ids, await_full_state: bool = True + ): res = {} for event in event_ids: res[event] = self._event_to_state_group[event] @@ -193,6 +199,8 @@ class StateTestCase(unittest.TestCase): "get_state_resolution_handler", "get_account_validity_handler", "get_macaroon_generator", + "get_instance_name", + "get_simple_http_client", "hostname", ] ) @@ -756,3 +764,43 @@ class StateTestCase(unittest.TestCase): result = yield defer.ensureDeferred(self.state.compute_event_context(event)) return result + + def test_make_state_cache_entry(self): + "Test that calculating a prev_group and delta is correct" + + new_state = { + ("a", ""): "E", + ("b", ""): "E", + ("c", ""): "E", + ("d", ""): "E", + } + + # old_state_1 has fewer differences to new_state than old_state_2, but + # the delta involves deleting a key, which isn't allowed in the deltas, + # so we should pick old_state_2 as the prev_group. + + # `old_state_1` has two differences: `a` and `e` + old_state_1 = { + ("a", ""): "F", + ("b", ""): "E", + ("c", ""): "E", + ("d", ""): "E", + ("e", ""): "E", + } + + # `old_state_2` has three differences: `a`, `c` and `d` + old_state_2 = { + ("a", ""): "F", + ("b", ""): "E", + ("c", ""): "F", + ("d", ""): "F", + } + + entry = _make_state_cache_entry(new_state, {1: old_state_1, 2: old_state_2}) + + self.assertEqual(entry.prev_group, 2) + + # There are three changes from `old_state_2` to `new_state` + self.assertEqual( + entry.delta_ids, {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"} + ) diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 37fada5c53..d3c13cf14c 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactorClock @@ -51,7 +50,7 @@ class TermsTestCase(unittest.HomeserverTestCase): def test_ui_auth(self): # Do a UI auth request - request_data = json.dumps({"username": "kermit", "password": "monkey"}) + request_data = {"username": "kermit", "password": "monkey"} channel = self.make_request(b"POST", self.url, request_data) self.assertEqual(channel.result["code"], b"401", channel.result) @@ -82,16 +81,14 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(channel.json_body["params"], expected_params) # We have to complete the dummy auth stage before completing the terms stage - request_data = json.dumps( - { - "username": "kermit", - "password": "monkey", - "auth": { - "session": channel.json_body["session"], - "type": "m.login.dummy", - }, - } - ) + request_data = { + "username": "kermit", + "password": "monkey", + "auth": { + "session": channel.json_body["session"], + "type": "m.login.dummy", + }, + } self.registration_handler.check_username = Mock(return_value=True) @@ -102,16 +99,14 @@ class TermsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.result["code"], b"401", channel.result) # Finish the UI auth for terms - request_data = json.dumps( - { - "username": "kermit", - "password": "monkey", - "auth": { - "session": channel.json_body["session"], - "type": "m.login.terms", - }, - } - ) + request_data = { + "username": "kermit", + "password": "monkey", + "auth": { + "session": channel.json_body["session"], + "type": "m.login.terms", + }, + } channel = self.make_request(b"POST", self.url, request_data) # We're interested in getting a response that looks like a successful diff --git a/tests/test_visibility.py b/tests/test_visibility.py index f338af6c36..c385b2f8d4 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -272,7 +272,7 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): "state_key": "@user:test", "content": {"membership": "invite"}, } - self.add_hashes_and_signatures(invite_pdu) + self.add_hashes_and_signatures_from_other_server(invite_pdu) invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id self.get_success( diff --git a/tests/unittest.py b/tests/unittest.py index c645dd3563..66ce92f4a6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -16,7 +16,6 @@ import gc import hashlib import hmac -import json import logging import secrets import time @@ -285,7 +284,7 @@ class HomeserverTestCase(TestCase): config=self.hs.config.server.listeners[0], resource=self.resource, server_version_string="1", - max_request_body_size=1234, + max_request_body_size=4096, reactor=self.reactor, ) @@ -619,20 +618,16 @@ class HomeserverTestCase(TestCase): want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) want_mac_digest = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": username, - "displayname": displayname, - "password": password, - "admin": admin, - "mac": want_mac_digest, - "inhibit_login": True, - } - ) - channel = self.make_request( - "POST", "/_synapse/admin/v1/register", body.encode("utf8") - ) + body = { + "nonce": nonce, + "username": username, + "displayname": displayname, + "password": password, + "admin": admin, + "mac": want_mac_digest, + "inhibit_login": True, + } + channel = self.make_request("POST", "/_synapse/admin/v1/register", body) self.assertEqual(channel.code, 200, channel.json_body) user_id = channel.json_body["user_id"] @@ -676,9 +671,7 @@ class HomeserverTestCase(TestCase): custom_headers: Optional[Iterable[CustomHeaderType]] = None, ) -> str: """ - Log in a user, and get an access token. Requires the Login API be - registered. - + Log in a user, and get an access token. Requires the Login API be registered. """ body = {"type": "m.login.password", "user": username, "password": password} if device_id: @@ -687,7 +680,7 @@ class HomeserverTestCase(TestCase): channel = self.make_request( "POST", "/_matrix/client/r0/login", - json.dumps(body).encode("utf8"), + body, custom_headers=custom_headers, ) self.assertEqual(channel.code, 200, channel.result) @@ -780,7 +773,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): verify_key_id, FetchKeyResult( verify_key=verify_key, - valid_until_ts=clock.time_msec() + 1000, + valid_until_ts=clock.time_msec() + 10000, ), ) ], @@ -838,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): client_ip=client_ip, ) - def add_hashes_and_signatures( + def add_hashes_and_signatures_from_other_server( self, event_dict: JsonDict, room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index bee66dee43..e8b6246ab5 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -20,7 +20,7 @@ from tests import unittest class DictCacheTestCase(unittest.TestCase): def setUp(self): - self.cache = DictionaryCache("foobar") + self.cache = DictionaryCache("foobar", max_entries=10) def test_simple_cache_hit_full(self): key = "test_simple_cache_hit_full" @@ -76,13 +76,13 @@ class DictCacheTestCase(unittest.TestCase): seq = self.cache.sequence test_value_1 = {"test": "test_simple_cache_hit_miss_partial"} - self.cache.update(seq, key, test_value_1, fetched_keys=set("test")) + self.cache.update(seq, key, test_value_1, fetched_keys={"test"}) seq = self.cache.sequence test_value_2 = {"test2": "test_simple_cache_hit_miss_partial2"} - self.cache.update(seq, key, test_value_2, fetched_keys=set("test2")) + self.cache.update(seq, key, test_value_2, fetched_keys={"test2"}) - c = self.cache.get(key) + c = self.cache.get(key, dict_keys=["test", "test2"]) self.assertEqual( { "test": "test_simple_cache_hit_miss_partial", @@ -90,3 +90,30 @@ class DictCacheTestCase(unittest.TestCase): }, c.value, ) + self.assertEqual(c.full, False) + + def test_invalidation(self): + """Test that the partial dict and full dicts get invalidated + separately. + """ + key = "some_key" + + seq = self.cache.sequence + # start by populating a "full dict" entry + self.cache.update(seq, key, {"a": "b", "c": "d"}) + + # add a bunch of individual entries, also keeping the individual + # entry for "a" warm. + for i in range(20): + self.cache.get(key, ["a"]) + self.cache.update(seq, f"key{i}", {1: 2}) + + # We should have evicted the full dict... + r = self.cache.get(key) + self.assertFalse(r.full) + self.assertTrue("c" not in r.value) + + # ... but kept the "a" entry that we kept querying. + r = self.cache.get(key, dict_keys=["a"]) + self.assertFalse(r.full) + self.assertEqual(r.value, {"a": "b"}) diff --git a/tests/utils.py b/tests/utils.py index 3059c453d5..d2c6d1e852 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,12 +15,17 @@ import atexit import os +from typing import Any, Callable, Dict, List, Tuple, Union, overload + +import attr +from typing_extensions import Literal, ParamSpec from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.logging.context import current_context, set_current_context +from synapse.server import HomeServer from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database @@ -50,12 +55,11 @@ SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres" -def setupdb(): +def setupdb() -> None: # If we're using PostgreSQL, set up the db once if USE_POSTGRES_FOR_TESTS: # create a PostgresEngine db_engine = create_engine({"name": "psycopg2", "args": {}}) - # connect to postgres to create the base database. db_conn = db_engine.module.connect( user=POSTGRES_USER, @@ -64,7 +68,7 @@ def setupdb(): password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) - db_conn.autocommit = True + db_engine.attempt_to_set_autocommit(db_conn, autocommit=True) cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) cur.execute( @@ -82,11 +86,11 @@ def setupdb(): port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) - db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") - prepare_database(db_conn, db_engine, None) - db_conn.close() + logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") + prepare_database(logging_conn, db_engine, None) + logging_conn.close() - def _cleanup(): + def _cleanup() -> None: db_conn = db_engine.module.connect( user=POSTGRES_USER, host=POSTGRES_HOST, @@ -94,7 +98,7 @@ def setupdb(): password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) - db_conn.autocommit = True + db_engine.attempt_to_set_autocommit(db_conn, autocommit=True) cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) cur.close() @@ -103,7 +107,19 @@ def setupdb(): atexit.register(_cleanup) -def default_config(name, parse=False): +@overload +def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: + ... + + +@overload +def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: + ... + + +def default_config( + name: str, parse: bool = False +) -> Union[Dict[str, object], HomeServerConfig]: """ Create a reasonable test config. """ @@ -151,6 +167,7 @@ def default_config(name, parse=False): "local": {"per_second": 10000, "burst_count": 10000}, "remote": {"per_second": 10000, "burst_count": 10000}, }, + "rc_joins_per_room": {"per_second": 10000, "burst_count": 10000}, "rc_invites": { "per_room": {"per_second": 10000, "burst_count": 10000}, "per_user": {"per_second": 10000, "burst_count": 10000}, @@ -169,7 +186,7 @@ def default_config(name, parse=False): # disable user directory updates, because they get done in the # background, which upsets the test runner. "update_user_directory": False, - "caches": {"global_factor": 1}, + "caches": {"global_factor": 1, "sync_response_cache_duration": 0}, "listeners": [{"port": 0, "type": "http"}], } @@ -181,90 +198,122 @@ def default_config(name, parse=False): return config_dict -def mock_getRawHeaders(headers=None): +def mock_getRawHeaders(headers=None): # type: ignore[no-untyped-def] headers = headers if headers is not None else {} - def getRawHeaders(name, default=None): + def getRawHeaders(name, default=None): # type: ignore[no-untyped-def] + # If the requested header is present, the real twisted function returns + # List[str] if name is a str and List[bytes] if name is a bytes. + # This mock doesn't support that behaviour. + # Fortunately, none of the current callers of mock_getRawHeaders() provide a + # headers dict, so we don't encounter this discrepancy in practice. return headers.get(name, default) return getRawHeaders +P = ParamSpec("P") + + +@attr.s(slots=True, auto_attribs=True) +class Timer: + absolute_time: float + callback: Callable[[], None] + expired: bool + + +# TODO: Make this generic over a ParamSpec? +@attr.s(slots=True, auto_attribs=True) +class Looper: + func: Callable[..., Any] + interval: float # seconds + last: float + args: Tuple[object, ...] + kwargs: Dict[str, object] + + class MockClock: - now = 1000 + now = 1000.0 - def __init__(self): - # list of lists of [absolute_time, callback, expired] in no particular - # order - self.timers = [] - self.loopers = [] + def __init__(self) -> None: + # Timers in no particular order + self.timers: List[Timer] = [] + self.loopers: List[Looper] = [] - def time(self): + def time(self) -> float: return self.now - def time_msec(self): - return self.time() * 1000 + def time_msec(self) -> int: + return int(self.time() * 1000) - def call_later(self, delay, callback, *args, **kwargs): + def call_later( + self, + delay: float, + callback: Callable[P, object], + *args: P.args, + **kwargs: P.kwargs, + ) -> Timer: ctx = current_context() - def wrapped_callback(): + def wrapped_callback() -> None: set_current_context(ctx) callback(*args, **kwargs) - t = [self.now + delay, wrapped_callback, False] + t = Timer(self.now + delay, wrapped_callback, False) self.timers.append(t) return t - def looping_call(self, function, interval, *args, **kwargs): - self.loopers.append([function, interval / 1000.0, self.now, args, kwargs]) - - def cancel_call_later(self, timer, ignore_errs=False): - if timer[2]: + def looping_call( + self, + function: Callable[P, object], + interval: float, + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + # This type-ignore should be redundant once we use a mypy release with + # https://github.com/python/mypy/pull/12668. + self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs)) # type: ignore[arg-type] + + def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None: + if timer.expired: if not ignore_errs: raise Exception("Cannot cancel an expired timer") - timer[2] = True + timer.expired = True self.timers = [t for t in self.timers if t != timer] # For unit testing - def advance_time(self, secs): + def advance_time(self, secs: float) -> None: self.now += secs timers = self.timers self.timers = [] for t in timers: - time, callback, expired = t - - if expired: + if t.expired: raise Exception("Timer already expired") - if self.now >= time: - t[2] = True - callback() + if self.now >= t.absolute_time: + t.expired = True + t.callback() else: self.timers.append(t) for looped in self.loopers: - func, interval, last, args, kwargs = looped - if last + interval < self.now: - func(*args, **kwargs) - looped[2] = self.now + if looped.last + looped.interval < self.now: + looped.func(*looped.args, **looped.kwargs) + looped.last = self.now - def advance_time_msec(self, ms): + def advance_time_msec(self, ms: float) -> None: self.advance_time(ms / 1000.0) - def time_bound_deferred(self, d, *args, **kwargs): - # We don't bother timing things out for now. - return d - -async def create_room(hs, room_id: str, creator_id: str): +async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None: """Creates and persist a creation event for the given room""" persistence_store = hs.get_storage_controllers().persistence + assert persistence_store is not None store = hs.get_datastores().main event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() |